PyTorch卷积神经网络对MNIST数据集的手写数字识别
作者:互联网
这个程序由两个文件组成:一个训练脚本(mnistTrain.py)和一个测试脚本(mnistTest.py)。安装好相应的依赖环境之后即可进行训练,MNIST数据集由torchvision.datasets.mnist包自动下载。
mnistTrain.py
# -*- coding: utf-8 -*-
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
from tqdm import tqdm
# Number of training epochs.
EPOCHS = 25
# Number of images per training batch.
BATCH_SIZE = 64
# Directory where the dataset is stored.
DATA_FOLDER = 'dataset'
# Path of the saved model file.
MODEL_FILE = 'MNIST_CNN.pkl'
# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class CNN(torch.nn.Module):
    """Single-convolution-layer classifier.

    The network is one ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d`` stage
    followed by a linear layer. The 5x5 convolution uses ``padding=2`` (and the
    default stride of 1), so it keeps the spatial size; the 2x2 max-pool then
    halves it. The linear layer is sized for ``14 * 14 * 32`` features, which
    corresponds to 28x28 input images.

    Args:
        in_channels: Number of channels of the input images (default 1).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
        )
        # 32 feature maps of 14x14 are flattened into one vector per sample.
        self.fc = torch.nn.Linear(14 * 14 * 32, num_classes)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Return the raw class logits for a batch of images.

        Args:
            feature: Batch of images; the linear layer requires that the
                convolution stage yields ``14 * 14 * 32`` values per sample.

        Returns:
            Tensor with one row of ``num_classes`` logits per sample.
        """
        out: torch.Tensor = self.conv(feature)
        # Keep the batch dimension, flatten everything else.
        out = out.flatten(1)
        return self.fc(out)
if __name__ == '__main__':
    # Training entry point: train the CNN on MNIST, evaluate it on the test
    # split after every epoch and keep the best-scoring model in MODEL_FILE.
    torch.set_num_threads(cpu_count())

    trainData = MNIST(DATA_FOLDER, train=True, transform=ToTensor(), download=True)
    testData = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    trainLoader = DataLoader(trainData, batch_size=BATCH_SIZE, shuffle=True)
    # Accuracy does not depend on sample order, so the test set is not shuffled.
    testLoader = DataLoader(testData, batch_size=128, shuffle=False)

    cnn = CNN().to(DEVICE)
    lossFunc = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=0.005)

    bestAccuracy = 0
    for epoch in range(EPOCHS):
        # Train: BatchNorm must use batch statistics and update its running
        # averages, so switch back to training mode at the start of each epoch.
        cnn.train()
        for images, labels in tqdm(trainLoader, desc=f'Epoch {epoch + 1}/{EPOCHS}'):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            predictions = cnn(images)
            loss = lossFunc(predictions, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate: eval mode makes BatchNorm use its running statistics (and
        # stops test data from leaking into them); no_grad avoids building
        # autograd graphs that are never used.
        cnn.eval()
        correct = 0
        with torch.no_grad():
            for images, labels in testLoader:
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)
                pred = cnn(images).argmax(dim=1)
                correct += (pred == labels).sum().item()
        accuracy = correct / len(testData.targets)

        if bestAccuracy < accuracy:
            bestAccuracy = accuracy
            # The whole module is pickled while in eval mode, so a loaded
            # checkpoint is ready for inference.
            torch.save(cnn, MODEL_FILE)
        print(f'Accuracy: {accuracy * 100}% Best Accuracy: {bestAccuracy * 100}%')
mnistTest.py
# -*- coding: utf-8 -*-
from mnistTrain import CNN, BATCH_SIZE, DATA_FOLDER, DEVICE, MODEL_FILE
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from tqdm import tqdm
if __name__ == '__main__':
    # Test entry point: load the model saved by the training script and report
    # its accuracy on the MNIST test split.
    testData = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    # Accuracy does not depend on sample order, so no shuffling is needed.
    testLoader = DataLoader(testData, batch_size=BATCH_SIZE, shuffle=False)

    # MODEL_FILE holds a fully pickled module, hence weights_only=False.
    # NOTE(security): unpickling can execute arbitrary code; only load model
    # files that come from a trusted source.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    cnn: CNN = torch.load(MODEL_FILE, map_location=DEVICE, weights_only=False).to(DEVICE)
    # Inference mode: BatchNorm must use its running statistics, not the
    # statistics of each test batch.
    cnn.eval()

    correct = 0
    with torch.no_grad():
        for images, labels in tqdm(testLoader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            # Call the module itself (not .forward) so registered hooks run.
            pred = cnn(images).argmax(dim=1)
            correct += (pred == labels).sum().item()
    accuracy = correct / len(testData.targets)
    print(f'Accuracy: {accuracy * 100}%')
标签:__,labels,Tensor,卷积,torch,Pytorch,images,import,MNIST 来源: https://www.cnblogs.com/fang-d/p/Pytorch_MNIST_CNN.html