其他分享
首页 > 其他分享> > Pytorch实战:CIFAR-10分类

Pytorch实战:CIFAR-10分类

作者:互联网

最近在学习Pytorch,先照着别人的代码过一遍,加油!!!

 

加载数据集

# 加载数据集及预处理
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch as t
show=ToPILImage() #可以将Tensor转成Image,方便可视化

划分数据集为训练集和测试集

#定义对数据的预处理
transform=transforms.Compose([
    transforms.ToTensor(),  #转为Tensor
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)), #归一化
])

#训练集
trainset=tv.datasets.CIFAR10(
    root='/home/cy/data',
    train=True,
    download=True,
    transform=transform
)

trainloader=t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)

testset=tv.datasets.CIFAR10(
    '/home/cy/data/',
    train=False,
    download=True,
    transform=transform
)

testloader=t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2
)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
Files already downloaded and verified
Files already downloaded and verified

可视化看下图片效果
(data, label)=trainset[100]
print(classes[label])

#(data+1)是为了还原被归一化的数据
show((data+1)/2).resize((100,100))

展示一个mini-batch中的图片

dataiter=iter(trainloader)
images,labels=dataiter.next() #返回4张图片及标签
print(' '.join('%11s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images+1)/2)).resize((400,100))

 

定义网络结构,挺方便的

## 定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
        
        
    def forward(self,x):
        x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))
        x=F.max_pool2d(F.relu(self.conv2(x)),2)
        x=x.view(x.size()[0],-1)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x

net=Net()
print(net)
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

定义损失函数和优化器
## 定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9) #随机梯度下降,stochastic gradient descent

开始训练网络

一共有三个步骤。输入数据,前向传播+反向传播,更新参数

from torch.autograd import Variable

for epoch in range(2):
    running_loss=0.0
    for i,data in enumerate(trainloader,0):
        #输入数据
        inputs,labels=data
        inputs,labels=Variable(inputs),Variable(labels)
        
        #梯度清零
        optimizer.zero_grad()
        
        #forward+backward
        outputs=net(inputs)
        loss=criterion(outputs,labels)
        loss.backward()
        
        #更新参数
        optimizer.step()
        
        #打印log信息
        #running_loss +=loss.data[0]
        running_loss +=loss.item()
        if i%2000 ==1999:   #每2000个batch打印一次训练状态
            print('[%d, %5d] loss: %.3f' \
                 %(epoch+1,i+1,running_loss / 2000))
            running_loss=0.0
print('Finished Training')

 

检查一下网络在一个batch内的效果如何

## 检验网络效果
dataiter=iter(testloader)
images,labels=dataiter.next() #一个batch返回4张图片
print('实际的label: ',' '.join(\
            '%08s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid(images/2 -0.5)).resize((400,100))

# 计算网络预测的label
outputs=net(Variable(images))
_,predicted=t.max(outputs.data,1)
print('预测结果: ',' '.join('%5s'\
        % classes[predicted[j]] for j in range(4)))

 

测试集上计算正确率

correct=0
total=0
for data in testloader:
    images,labels=data
    outputs=net(Variable(images))
    _,predicted=t.max(outputs.data,1)
    total +=labels.size(0)
    correct +=(predicted==labels).sum()
    
print('1000张测试集中的准确率为: %d  %%' %(100* correct/total))
1000张测试集中的准确率为: 52  %

 

可以看到,在CIFAR-10上的正确率为52%,网络训练还是有些效果的。

 

标签:10,nn,loss,self,labels,CIFAR,Pytorch,print,data
来源: https://www.cnblogs.com/keeptry/p/13943820.html