搭建网络模型的笔记
作者:互联网
搭建网络模型
1. 导入模块
- import 模块
2. 选择设备
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
3. 准备数据集
- 训练集:
- train_data = torchvision.datasets.CIFAR10(root="./data_CIFAR10", train=True,transform=torchvision.transforms.ToTensor(),download=True)
- 测试集:
- test_data = torchvision.datasets.CIFAR10(root="./data_CIFAR10", train=False,transform=torchvision.transforms.ToTensor(),download=True)
4. 加载数据集
- train_dataloader = DataLoader(train_data, batch_size=64)
- test_dataloader = DataLoader(test_data, batch_size=64)
5. 创建网络模型
- class MyModel(nn.Module):
- def init(self):
- xxxxxxx
- xxxxxxx
- def forward(self, x):
- x = self.model1(x)
- return x
- def init(self):
6. 实例化网络模型
- net_model = MyModel()
- net_model = net_model.to(device)
7. 定义损失函数
- loss_fn = nn.xxxxxxxLoss()
- if torch.cuda.is_available():
- loss_fn = loss_fn.cuda()
8. 定义优化器
- optimizer = torch.optim.SGD(net_model.parameters(), lr=learning_rate)
- optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9) # 选择优化器
- scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略
9. 训练部分
- for i in range(epoch):
-
训练步骤
-
for data in train_dataloader:
- 1.获取训练数据
- imgs, targets = data
- 2.选择设备
- imgs = imgs.to(device)
- 3.把图片传入网络模型进行训练,返回10个结果
- targets = targets.to(device)
- outputs = net_model(imgs)
- 4.进行损失函数处理
- loss = loss_fn(outputs, targets)
- 5.梯度清零
- optimizer.zero_grad()
- 6.反向传播
- loss.backward()
- 7.优化器,更新权重
- optimizer.step()
-
测试步骤
-
with torch.no_grad():
- for data in test_dataloader:
-
imgs, targets = data # 1.获取测试数据
-
imgs = imgs.to(device) # 2.选择设备
-
targets = targets.to(device)
-
outputs = net_model(imgs) # 3.将测试图片传入训练模型
-
loss = loss_fn(outputs, targets) # 4.计算损失值
-
total_test_loss = total_test_loss + loss.item() # 5.计算总的损失值
-
accuracy = (outputs.argmax(1) == targets).sum() # 6.计算准确率
-
total_accuarcy = total_accuarcy + accuracy # 7.计算总准确率
-
- for data in test_dataloader:
-
保存训练好的模型
-
torch.save(net_model, "net_model{}.path".format(i))
-
10. 验证数据
-
待验证图片预处理
- 转换为和测试集相同格式的图片,输入为同类型
-
- 加载
- transfrom = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])
- image = transfrom(image)
-
- 转换
- image = torch.reshape(image, (1, 3, 32, 32))
-
加载保存的模型
- model = torch.load("net_model25.path")
-
验证
-
model.eval()
-
with torch.no_grad():
- image = image.to(device)# 转换成相同类型数据集
- output = model(image)
-
print(output)
-
print(output.argmax(1))
标签:loss,模型,torch,笔记,targets,net,data,model,搭建 来源: https://www.cnblogs.com/cmn-note/p/15190053.html