其他分享
首页 > 其他分享> > pytorch:cifar-10+lenet5代码实现

pytorch:cifar-10+lenet5代码实现

作者:互联网

本篇代码有不清楚的地方,可以参考:
cifar-10+resnet.
这篇除了搭建的CNN不一样,其他地方完全一样。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

cifar_train = datasets.CIFAR10('cifar',True,transform=transforms.Compose((  #true表示加载的是训练集
                                transforms.Resize(32,32),
                                transforms.ToTensor()))) 

cifar_train_batch = DataLoader(cifar_train,batch_size = 30,shuffle = True)

cifar_test = datasets.CIFAR10('cifar',False,transform=transforms.Compose((  #false表示加载的是测试集
                                transforms.Resize(32,32),
                                transforms.ToTensor()))) 

cifar_test_batch = DataLoader(cifar_test_one,batch_size = 30,shuffle = True)

搭建CNN:

from torch import nn
class lenet5(nn.Module):
    def __init__(self):
        super(lenet5,self).__init__()
        
        #两层卷积
        self.conv_unit = nn.Sequential(     
        nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5,stride=1,padding=0),
        nn.AvgPool2d(kernel_size=2,stride=2,padding=0),            
        nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,stride=1,padding=0),
        nn.AvgPool2d(kernel_size=2,stride=2,padding=0)
        )
        
        '''
        卷积之后先flatten,然后再全连接,但是nn.Module中没有flatten的操作,
        所以flatten不能包含在sequential中
        因此先用一个sequential完成卷积操作,然后flatten,然后再用一个sequential完成全连接
        '''
        
        #全连接层
        self.fc_unit = nn.Sequential(
        nn.Linear(16*5*5,120),
        nn.ReLU(),
        nn.Linear(120,84),
        nn.ReLU(),
        nn.Linear(84,10)
        )        
        
    def forward(self,x):
        batch_size = x.shape[0]
        x = self.conv_unit(x)
        x = x.reshape(batch_size,16*5*5)
        logits = self.fc_unit(x)
        return logits
device = torch.device('cuda')
net = lenet5()
net = net.to(device) #将网络部署到GPU上
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(),lr=1e-3)
#开始训练
for epoch in range(5):
    for batchidx,(x,label) in enumerate(cifar_train_batch):
        x,label = x.to(device),label.to(device) #x.size (bcs,3,32,32) label.size (bcs)
        logits = net.forward(x)
        loss = loss_fn(logits,label) #logits.size:bcs*10,label.size:bcs
        
        #开始反向传播:
        optimizer.zero_grad()
        loss.backward() #计算gradient
        optimizer.step() #更新参数
        if (batchidx+1)%400 == 0:
            print('这是本次迭代的第{}个batch'.format(batchidx+1))  #本例中一共有50000张照片,每个batch有30张照片,所以一个epoch有1667个batch
    
    print('这是第{}迭代,loss是{}'.format(epoch+1,loss.item()))
        
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第1迭代,loss是1.1926124095916748
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第2迭代,loss是1.1064329147338867
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第3迭代,loss是0.8839625120162964
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第4迭代,loss是1.0676394701004028
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第5迭代,loss是1.0307090282440186
#测试
net.eval()
with torch.no_grad():
    correct_num = 0 #预测正确的个数
    total_num = 0 #测试集中总的照片张数
    batch_num = 0 #第几个batch
    for x,label in cifar_test: #x的size是30*3*32*32(30是batch_size,3是通道数),label的size是30.
                               #cifar_test中一共有10000张照片,所以一共有334个batch,因此要循环334次
        x,label = x.to(device),label.to(device) 
        logits = net.forward(x)
        pred = logits.argmax(dim=1)
        correct_num += torch.eq(pred,label).float().sum().item()
        total_num += x.size(0)
        batch_num += 1
        if batch_num%50 == 0:
            print('这是第{}个batch'.format(batch_num)) #一共有10000/30≈334个batch
            
    acc = correct_num/total_num  #最终的total_num是10000
    print('测试集上的准确率为:',acc)
    
这是第50个batch
这是第100个batch
这是第150个batch
这是第200个batch
这是第250个batch
这是第300个batch
测试集上的准确率为: 0.5496


在这里插入图片描述

标签:10,lenet5,迭代,loss,batch,cifar,nn,size
来源: https://blog.csdn.net/weixin_41391619/article/details/104883994