深度学习项目中_5-训练模块的编写
作者:互联网
1.训练模块中的步骤
训练模块一般是保存在train.py 的文件中, 该模块中一般包含以下步骤:
-
导入各类模块, (标准库, 三方库, cv, torch, torchvision), 如果在model.py ( 自己定义网络模型文件), loss.py (自定义的损失函数), utils.py( 自定义的各种方法), config.py(整体项目的配置文件), 这些模块都需要导入
-
命令行解析
-
数据集加载
-
检测模型保存地址是否存在, 如果不存在则创建;
-
实例化网络模型;
-
实例化 损失函数和优化器
-
准备事件 文件, 方便 Tensorboard --logdir=“run”, 可视化训练过程;
-
检查是否采用,接着上一次的检查点checkpoint 训练, 若是加载chepoint. 模型;
-
开始训练, 循环epochs:
– 将梯度置零;
– 求 loss;
– 反向传播;
– 更新权重参数;
– 更新优化器中的学习率(可选) -
可视化指标;
-
验证valid 模型, (根据模型在验证集上损失和度量, 调整模型的超参数)
2. 训练模块的代码实例
2.1 训练模块 示范一
import os
import torch
from torch.utils.data import DataLoader
from torch import nn
import argparse
from tensorboardX import SummaryWriter
from data_preparation.data_preparation import FileDateset
from model.Baseline import Base_model
from model.ops import pytorch_LSD
def parse_args():
parser = argparse.ArgumentParser()
# 重头开始训练 defaule=None, 继续训练defaule设置为'/**.pth'
parser.add_argument("--model_name", type=str, default=None, help="是否加载模型继续训练 '/50.pth' None")
parser.add_argument("--batch-size", type=int, default=16, help="")
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument('--lr', type=float, default=3e-4, help='学习率 (default: 0.01)')
parser.add_argument('--train_data', default="./data_preparation/Synthetic/TRAIN", help='数据集的path')
parser.add_argument('--val_data', default="./data_preparation/Synthetic/VAL", help='验证样本的path')
parser.add_argument('--checkpoints_dir', default="./checkpoints/AEC_baseline", help='模型检查点文件的路径(以继续培训)')
parser.add_argument('--event_dir', default="./event_file/AEC_baseline", help='tensorboard事件文件的地址')
args = parser.parse_args()
return args
def main():
args = parse_args()
print("GPU是否可用:", torch.cuda.is_available()) # True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 实例化 Dataset
train_set = FileDateset(dataset_path=args.train_data) # 实例化训练数据集
val_set = FileDateset(dataset_path=args.val_data) # 实例化验证数据集
# 数据加载器
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, drop_last=True)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True)
# ########### 保存检查点的地址(如果检查点不存在,则创建) ############
if not os.path.exists(args.checkpoints_dir):
os.makedirs(args.checkpoints_dir)
################################
# 实例化模型 #
################################
model = Base_model().to(device) # 实例化模型
# summary(model, input_size=(322, 999)) # 模型输出 torch.Size([64, 322, 999])
# ########### 损失函数 ############
criterion = nn.MSELoss(reduce=True, size_average=True, reduction='mean')
###############################
# 创建优化器 Create optimizers #
###############################
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, )
# lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,20], gamma=0.1)
# ########### TensorBoard可视化 summary ############
writer = SummaryWriter(args.event_dir) # 创建事件文件
# ########### 加载模型检查点 ############
start_epoch = 0
if args.model_name:
print("加载模型:", args.checkpoints_dir + args.model_name)
checkpoint = torch.load(args.checkpoints_dir + args.model_name)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint['epoch']
# lr_schedule.load_state_dict(checkpoint['lr_schedule']) # 加载lr_scheduler
for epoch in range(start_epoch, args.epochs):
model.train() # 训练模型
for batch_idx, (train_X, train_mask, train_nearend_mic_magnitude, train_nearend_magnitude) in enumerate(
train_loader):
train_X = train_X.to(device) # 远端语音cat麦克风语音 [batch_size, 322, 999] (, F, T)
train_mask = train_mask.to(device) # IRM [batch_size 161, 999]
train_nearend_mic_magnitude = train_nearend_mic_magnitude.to(device)
train_nearend_magnitude = train_nearend_magnitude.to(device)
# 前向传播
pred_mask = model(train_X) # [batch_size, 322, 999]--> [batch_size, 161, 999]
train_loss = criterion(pred_mask, train_mask)
# 近端语音信号频谱 = mask * 麦克风信号频谱 [batch_size, 161, 999]
pred_near_spectrum = pred_mask * train_nearend_mic_magnitude
train_lsd = pytorch_LSD(train_nearend_magnitude, pred_near_spectrum)
# 反向传播
optimizer.zero_grad() # 将梯度清零
train_loss.backward() # 反向传播
optimizer.step() # 更新参数
# ########### 可视化打印 ############
print('Train Epoch: {} Loss: {:.6f} LSD: {:.6f}'.format(epoch + 1, train_loss.item(), train_lsd.item()))
# ########### TensorBoard可视化 summary ############
# lr_schedule.step() # 学习率衰减
# writer.add_scalar(tag="lr", scalar_value=model.state_dict()['param_groups'][0]['lr'], global_step=epoch + 1)
writer.add_scalar(tag="train_loss", scalar_value=train_loss.item(), global_step=epoch + 1)
writer.add_scalar(tag="train_lsd", scalar_value=train_lsd.item(), global_step=epoch + 1)
writer.flush()
# 神经网络在验证数据集上的表现
model.eval() # 测试模型
# 测试的时候不需要梯度
with torch.no_grad():
for val_batch_idx, (val_X, val_mask, val_nearend_mic_magnitude, val_nearend_magnitude) in enumerate(
val_loader):
val_X = val_X.to(device) # 远端语音cat麦克风语音 [batch_size, 322, 999] (, F, T)
val_mask = val_mask.to(device) # IRM [batch_size 161, 999]
val_nearend_mic_magnitude = val_nearend_mic_magnitude.to(device)
val_nearend_magnitude = val_nearend_magnitude.to(device)
# 前向传播
val_pred_mask = model(val_X)
val_loss = criterion(val_pred_mask, val_mask)
# 近端语音信号频谱 = mask * 麦克风信号频谱 [batch_size, 161, 999]
val_pred_near_spectrum = val_pred_mask * val_nearend_mic_magnitude
val_lsd = pytorch_LSD(val_nearend_magnitude, val_pred_near_spectrum)
# ########### 可视化打印 ############
print(' val Epoch: {} \tLoss: {:.6f}\tlsd: {:.6f}'.format(epoch + 1, val_loss.item(), val_lsd.item()))
######################
# 更新tensorboard #
######################
writer.add_scalar(tag="val_loss", scalar_value=val_loss.item(), global_step=epoch + 1)
writer.add_scalar(tag="val_lsd", scalar_value=val_lsd.item(), global_step=epoch + 1)
writer.flush()
# # ########### 保存模型 ############
if (epoch + 1) % 10 == 0:
checkpoint = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch + 1,
# 'lr_schedule': lr_schedule.state_dict()
}
torch.save(checkpoint, '%s/%d.pth' % (args.checkpoints_dir, epoch + 1))
if __name__ == "__main__":
main()
2.2 训练模块 示范二
作者:石郎
链接:https://www.zhihu.com/question/406133826/answer/1334319004
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
# 定义网络
net = Net()
# 定义数据
#数据预处理,1.转为tensor,2.归一化
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 验证集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 开始训练
net.train()
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# 将梯度置为0
# zero the parameter gradients
optimizer.zero_grad()
# 求loss
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
# 梯度反向传播
loss.backward()
# 由梯度,更新参数
optimizer.step()
# 可视化
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
# 查看在验证集上的效果
dataiter = iter(testloader)
images, labels = dataiter.next()
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
net.eval()
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
for j in range(4)))
3 训练模块 中输入的超参数
3.1 超参数定义
在机器学习的过程中,
超参= 在开始机器学习之前,就人为设置好的参数。
模型参数=通过训练得到的参数数据。
通常情况下,需要对超参数进行优化,给学习机选择一组最优超参数,以提高学习的性能和效果
3.2 深度学习中的常见 超参数
一个深度学习网络有很多的参数可以配置,一般分成以下三类:
- 数据集参数(文件路径、batch_size等)
- 训练参数(学习率、训练epoch等)
- 模型参数(输入的大小,输出的大小)
这些参数可以写一个类保存,也可以写一个字典,然后使用json保存,这些都是需要自己去实现的,但是这些都是一些细枝末节东西,
多写几次,找到一个自己最喜欢的方式就可以,不是深度学习项目中必要的部分。
标签:loss,val,args,epoch,train,模块,深度,编写,model 来源: https://blog.csdn.net/chumingqian/article/details/123236656