其他分享
首页 > 其他分享> > pytorch P28 -卷积神经网络demo

pytorch P28 -卷积神经网络demo

作者:互联网

卷积神经网络与 传统神经 网络的训练模块基本一致,网络 模型差异较大。

一 读取数据

# 导包
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np

#读取 数据
# 定义超参数
input_size = 28 # 图像大小:28 * 28
num_classes = 10 # 标签的种类
num_epochs = 3 # 迭代的次数
batch_size = 64 # 每个批次的大小,即每64章图片一块进行一次训练

# 加载训练集
train_dataset = datasets.MNIST(
                                root='./data',
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True
                                )
# 记载测试集
test_dataset = datasets.MNIST(root='./data',
                             train=False,
                             transform=transforms.ToTensor())
# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                         batch_size=batch_size,
                                         )

前面是导包

数据源还是mnist,分为训练集与测试集。使用DataLoader来构建batch数据。

二 搭建卷积神经网络模型

#网络 模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Sequential(  # 输入大小 (1,28,28)
            nn.Conv2d(
                in_channels=1,  # 说明是灰度图
                out_channels=16,  # 要得到多少个特征图
                kernel_size=5,  # 卷积核的大小
                stride=1,  # 步长
                padding=2),  # 边缘填充的大小
            nn.ReLU(),  #relu层
            nn.MaxPool2d(kernel_size=2)  # 池化操作 (2 * 2) 输出结果为: (16,14,14)

        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),#16 是输入,32 是特征图
            nn.ReLU(), #relu层 
            nn.MaxPool2d(2))  # 输出 (32, 7, 7)
        self.out = nn.Linear(32 * 7 * 7, 10)  # 全连接输入分类

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten操作,结果为 (batch_size, 32*7*7)
        output = self.out(x)
        return output
#准确率
def accuracy(predictions,labels):
    pred = torch.max(predictions.data,1)[1]
    rights = pred.eq(labels.data.view_as(pred)).sum()
    return rights,len(labels)

这里面的主要 参数 ,conv1与conv2 里面的 需要结合老师的视频区理解。还有前向传播调用out之前的view 变形操作。要结合上一个节的从矩阵降维到全连接理解 。

   三 训练网络模型

# 实例化
net = CNN()
# 选择损失函数
criterion = nn.CrossEntropyLoss()
# 选择优化器
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器,采用随机梯度下降算法

# 开始进行训练
for epoch in range(num_epochs):
    train_right = []  # 保存当前epoch的结果,和之前定义一个保存loss的是一个道理

    for batch_idx, (data, target) in enumerate(train_loader):
        net.train()
        output = net(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # 优化器调用step(),不是loss
        right = accuracy(output, target)
        train_right.append(right)

        if batch_idx % 100 == 0:#每100次到验证集看效果
            net.eval()
            val_right = []
            for (data, target) in test_loader:
                output = net(data)
                right = accuracy(output, target)
                val_right.append(right)

            # 准确率的计算
            train_rate = (sum([tup[0] for tup in train_right]), sum(tup[1] for tup in train_right))
            val_rate = (sum([tup[0] for tup in val_right]), sum(tup[1] for tup in val_right))

            print('当前epoch:{}[{}/{}({:.0f}%)]\t 损失:{:.6f}\t 训练集准确率:{:.2f}%\t 测试集准确率:{:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                       100.0 * batch_idx / len(train_loader),
                loss.data,
                       100.0 * train_rate[0].numpy() / train_rate[1],
                       100.0 * val_rate[0].numpy() / val_rate[1]
            ))

没100次,看验证集输出结果:

当前epoch:0[0/60000(0%)]	 损失:2.300367	 训练集准确率:12.50%	 测试集准确率:10.14%
当前epoch:0[6400/60000(11%)]	 损失:0.280655	 训练集准确率:74.13%	 测试集准确率:92.48%
当前epoch:0[12800/60000(21%)]	 损失:0.166317	 训练集准确率:83.76%	 测试集准确率:95.60%
当前epoch:0[19200/60000(32%)]	 损失:0.105674	 训练集准确率:87.71%	 测试集准确率:95.65%
当前epoch:0[25600/60000(43%)]	 损失:0.094606	 训练集准确率:89.83%	 测试集准确率:97.22%
当前epoch:0[32000/60000(53%)]	 损失:0.065384	 训练集准确率:91.26%	 测试集准确率:97.60%
当前epoch:0[38400/60000(64%)]	 损失:0.049964	 训练集准确率:92.25%	 测试集准确率:97.51%
当前epoch:0[44800/60000(75%)]	 损失:0.035163	 训练集准确率:93.01%	 测试集准确率:97.83%
当前epoch:0[51200/60000(85%)]	 损失:0.055695	 训练集准确率:93.56%	 测试集准确率:98.14%
当前epoch:0[57600/60000(96%)]	 损失:0.014890	 训练集准确率:94.03%	 测试集准确率:97.77%
当前epoch:1[0/60000(0%)]	 损失:0.081240	 训练集准确率:93.75%	 测试集准确率:98.20%
当前epoch:1[6400/60000(11%)]	 损失:0.049458	 训练集准确率:98.04%	 测试集准确率:98.24%
当前epoch:1[12800/60000(21%)]	 损失:0.026402	 训练集准确率:98.12%	 测试集准确率:98.18%
当前epoch:1[19200/60000(32%)]	 损失:0.056982	 训练集准确率:98.11%	 测试集准确率:98.49%
当前epoch:1[25600/60000(43%)]	 损失:0.098775	 训练集准确率:98.13%	 测试集准确率:98.63%
当前epoch:1[32000/60000(53%)]	 损失:0.119748	 训练集准确率:98.15%	 测试集准确率:98.26%
当前epoch:1[38400/60000(64%)]	 损失:0.024341	 训练集准确率:98.18%	 测试集准确率:98.49%
当前epoch:1[44800/60000(75%)]	 损失:0.017717	 训练集准确率:98.20%	 测试集准确率:97.95%
当前epoch:1[51200/60000(85%)]	 损失:0.084650	 训练集准确率:98.20%	 测试集准确率:98.45%
当前epoch:1[57600/60000(96%)]	 损失:0.014650	 训练集准确率:98.18%	 测试集准确率:98.68%
当前epoch:2[0/60000(0%)]	 损失:0.089021	 训练集准确率:96.88%	 测试集准确率:98.54%
当前epoch:2[6400/60000(11%)]	 损失:0.048318	 训练集准确率:98.72%	 测试集准确率:98.68%
当前epoch:2[12800/60000(21%)]	 损失:0.051317	 训练集准确率:98.71%	 测试集准确率:98.62%
当前epoch:2[19200/60000(32%)]	 损失:0.033962	 训练集准确率:98.67%	 测试集准确率:98.53%
当前epoch:2[25600/60000(43%)]	 损失:0.025890	 训练集准确率:98.72%	 测试集准确率:98.79%
当前epoch:2[32000/60000(53%)]	 损失:0.007487	 训练集准确率:98.72%	 测试集准确率:98.57%
当前epoch:2[38400/60000(64%)]	 损失:0.015440	 训练集准确率:98.74%	 测试集准确率:98.81%
当前epoch:2[44800/60000(75%)]	 损失:0.006676	 训练集准确率:98.73%	 测试集准确率:98.84%
当前epoch:2[51200/60000(85%)]	 损失:0.034487	 训练集准确率:98.72%	 测试集准确率:98.85%
当前epoch:2[57600/60000(96%)]	 损失:0.042631	 训练集准确率:98.73%	 测试集准确率:98.73%

标签:训练,demo,epoch,损失,准确率,pytorch,60000,测试,P28
来源: https://blog.csdn.net/bohu83/article/details/122766429