其他分享
首页 > 其他分享> > 从零写CRNN文字识别 —— (6)训练

从零写CRNN文字识别 —— (6)训练

作者:互联网

前言

完整代码已经上传github:https://github.com/xmy0916/pytorch_crnn

训练

训练部分的代码逻辑如下:

for epoch in range(total_epoch):
  for data in dataloader:
    数据输入模型(前馈)
    根据输出计算loss
    loss反馈更新网络参数
  if epoch % eval_epoch == 0:
    评估数据输入模型(前馈)
    根据输出计算loss
    解码输出计算识别准确率
    if now_acc > best_acc:
      保存模型

对应的完整代码如下:

# 训练
    best_acc = 0.0
    for epoch in range(last_epoch,config.TRAIN.END_EPOCH):
      model.train()
      for i, (inp, idx) in enumerate(train_loader):
          # 前馈
          inp = inp.to(device)
          preds = model(inp).to(device)
          # 计算loss
          labels = get_batch_label(train_dataset, idx)
          batch_size = inp.size(0)
          text, length = encode(config.DICT,labels)
          preds_size = torch.IntTensor([preds.size(0)] * batch_size)
          loss = criterion(preds, text, preds_size, length)
          # 反馈
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          if i % config.PRINT_FREQ == 0:
            print("epoch:{} step:{} loss:{} lr:{}".format(epoch,i,loss.item(),lr_scheduler.get_lr()))
      # 每个epoch更新学习率
      lr_scheduler.step()

      # 每EVAL_FREQ评估一次并保存best模型
      if epoch % config.EVAL_FREQ == 0:
          model.eval()
          n_correct = 0
          test_num = len(val_loader) * config.TEST.BATCH_SIZE_PER_GPU
          with torch.no_grad():
              for i, (inp, idx) in enumerate(val_loader):
                  # 计算前馈
                  inp = inp.to(device)
                  preds = model(inp).cpu()
                  # 计算loss
                  labels = get_batch_label(val_dataset, idx)
                  batch_size = inp.size(0)
                  text, length = encode(config.DICT,labels)
                  preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                  loss = criterion(preds, text, preds_size, length)
                  # 后处理解码
                  print("网络输出的preds的shape:",preds.cpu().detach().shape)
                  _, preds = preds.max(2)
                  print("max(2)的shape:",preds.cpu().detach().shape)
                  preds = preds.transpose(1, 0).contiguous().view(-1)
                  print("transpose的shape:",preds.cpu().detach().shape)
                  sim_preds = decode(preds.data, preds_size.data, config.DICT,raw=False)
                  for pred, target in zip(sim_preds, labels):
                    if pred == target:
                      n_correct += 1

              
          # 抓一个batch来显示
          raw_preds = decode(preds.data, preds_size.data, config.DICT, raw=True)[:config.TEST.NUM_TEST_DISP]
          for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
              print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
          print("preds:",preds.cpu().detach().numpy())
          print("preds_shape:",preds.cpu().detach().shape)
          print("dict:",config.DICT)
          now_acc = n_correct * 1.0 / test_num
          print("best_acc:{} correct:{}".format(now_acc,n_correct))
          if now_acc >= best_acc:
              torch.save(
                    {
                        "state_dict": model.state_dict(),
                        "epoch": epoch + 1,
                        # "optimizer": optimizer.state_dict(),
                        # "lr_scheduler": lr_scheduler.state_dict(),
                        "best_acc": best_acc,
                    },  os.path.join(config.OUTPUT_DIR, "checkpoint_{}_acc_{:.4f}.pth".format(epoch, now_acc)))
              best_acc = now_acc
              print("save_model!")

看看评估过程(摘一段代码出来):

preds = model(inp).cpu()
# 计算loss
labels = get_batch_label(val_dataset, idx)
batch_size = inp.size(0)
text, length = encode(config.DICT,labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
# 后处理解码
print("网络输出的preds的shape:",preds.cpu().detach().shape)
_, preds = preds.max(2)
print("max(2)的shape:",preds.cpu().detach().shape)
preds = preds.transpose(1, 0).contiguous().view(-1)
print("transpose的shape:",preds.cpu().detach().shape)

打印结果:
在这里插入图片描述
稍微解释下:
preds的shape[41,16,109]:

preds.max(2)得到了从属于那一类的向量,2表示在109的纬度上取所以输出的shape是[41,16]
transpose是把二维向量拉平,656=41*16
这里注意一点,测试的时候每个batch_size是16,但是我们数据集不一定是16的整数倍,所以最后一个batch的大小不一定有16,例如我们这里最后一个batch的大小是14:
在这里插入图片描述
在代码中我将最后一个batch的测试图片可视化的打印了,结果如下:
在这里插入图片描述
这是第一个epoch训练的输出,
在这里插入图片描述
上图的横杠是设置的空字符的占位符,在config/config.yml中设置这个字符BLANK_CHAR
在这里插入图片描述
上图一共574个0,574 = 41 * 14因为是最后一个batch所以不够16个,上图理论上可以解码成574个字符,因为这是第一个epoch训练的结果,网络参数基本不对所以没有输出。

第16个epoch输出如下:
在这里插入图片描述
第一行的37这个值就是dict中L的位置
在这里插入图片描述

标签:acc,batch,epoch,preds,shape,CRNN,零写,识别,size
来源: https://blog.csdn.net/qq_37668436/article/details/113794325