从零写CRNN文字识别 —— (6)训练
作者:互联网
前言
完整代码已经上传github:https://github.com/xmy0916/pytorch_crnn
训练
训练部分的代码逻辑如下:
for epoch in range(total_epoch):
for data in dataloader:
数据输入模型(前馈)
根据输出计算loss
loss反馈更新网络参数
if epoch % eval_epoch == 0:
评估数据输入模型(前馈)
根据输出计算loss
解码输出计算识别准确率
if now_acc > best_acc:
保存模型
对应的完整代码如下:
# 训练
best_acc = 0.0
for epoch in range(last_epoch,config.TRAIN.END_EPOCH):
model.train()
for i, (inp, idx) in enumerate(train_loader):
# 前馈
inp = inp.to(device)
preds = model(inp).to(device)
# 计算loss
labels = get_batch_label(train_dataset, idx)
batch_size = inp.size(0)
text, length = encode(config.DICT,labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
# 反馈
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % config.PRINT_FREQ == 0:
print("epoch:{} step:{} loss:{} lr:{}".format(epoch,i,loss.item(),lr_scheduler.get_lr()))
# 每个epoch更新学习率
lr_scheduler.step()
# 每EVAL_FREQ评估一次并保存best模型
if epoch % config.EVAL_FREQ == 0:
model.eval()
n_correct = 0
test_num = len(val_loader) * config.TEST.BATCH_SIZE_PER_GPU
with torch.no_grad():
for i, (inp, idx) in enumerate(val_loader):
# 计算前馈
inp = inp.to(device)
preds = model(inp).cpu()
# 计算loss
labels = get_batch_label(val_dataset, idx)
batch_size = inp.size(0)
text, length = encode(config.DICT,labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
# 后处理解码
print("网络输出的preds的shape:",preds.cpu().detach().shape)
_, preds = preds.max(2)
print("max(2)的shape:",preds.cpu().detach().shape)
preds = preds.transpose(1, 0).contiguous().view(-1)
print("transpose的shape:",preds.cpu().detach().shape)
sim_preds = decode(preds.data, preds_size.data, config.DICT,raw=False)
for pred, target in zip(sim_preds, labels):
if pred == target:
n_correct += 1
# 抓一个batch来显示
raw_preds = decode(preds.data, preds_size.data, config.DICT, raw=True)[:config.TEST.NUM_TEST_DISP]
for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
print("preds:",preds.cpu().detach().numpy())
print("preds_shape:",preds.cpu().detach().shape)
print("dict:",config.DICT)
now_acc = n_correct * 1.0 / test_num
print("best_acc:{} correct:{}".format(now_acc,n_correct))
if now_acc >= best_acc:
torch.save(
{
"state_dict": model.state_dict(),
"epoch": epoch + 1,
# "optimizer": optimizer.state_dict(),
# "lr_scheduler": lr_scheduler.state_dict(),
"best_acc": best_acc,
}, os.path.join(config.OUTPUT_DIR, "checkpoint_{}_acc_{:.4f}.pth".format(epoch, now_acc)))
best_acc = now_acc
print("save_model!")
看看评估过程(摘一段代码出来):
preds = model(inp).cpu()
# 计算loss
labels = get_batch_label(val_dataset, idx)
batch_size = inp.size(0)
text, length = encode(config.DICT,labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
# 后处理解码
print("网络输出的preds的shape:",preds.cpu().detach().shape)
_, preds = preds.max(2)
print("max(2)的shape:",preds.cpu().detach().shape)
preds = preds.transpose(1, 0).contiguous().view(-1)
print("transpose的shape:",preds.cpu().detach().shape)
打印结果:
稍微解释下:
preds的shape[41,16,109]:
- 41是卷积后的长度
- 16是测试时的batch_size大小
- 109是字典的类别数
preds.max(2)得到了从属于那一类的向量,2表示在109的纬度上取所以输出的shape是[41,16]
transpose是把二维向量拉平,656=41*16
这里注意一点,测试的时候每个batch_size是16,但是我们数据集不一定是16的整数倍,所以最后一个batch的大小不一定有16,例如我们这里最后一个batch的大小是14:
在代码中我将最后一个batch的测试图片可视化的打印了,结果如下:
这是第一个epoch训练的输出,
上图的横杠是设置的空字符的占位符,在config/config.yml中设置这个字符BLANK_CHAR
上图一共574个0,574 = 41 * 14因为是最后一个batch所以不够16个,上图理论上可以解码成574个字符,因为这是第一个epoch训练的结果,网络参数基本不对所以没有输出。
第16个epoch输出如下:
第一行的37这个值就是dict中L的位置
标签:acc,batch,epoch,preds,shape,CRNN,零写,识别,size 来源: https://blog.csdn.net/qq_37668436/article/details/113794325