其他分享
首页 > 其他分享 > PyTorch卷积神经网络对MNIST数据集的手写数字识别

PyTorch卷积神经网络对MNIST数据集的手写数字识别

作者:互联网

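这个程序由两个文件组成：一个训练脚本（mnistTrain.py）和一个测试脚本（mnistTest.py）。安装好相应的依赖环境（PyTorch、torchvision、tqdm）之后即可进行训练，MNIST数据集会通过torchvision.datasets.mnist包自动下载。请先运行训练脚本，它会把测试准确率最高的模型保存为MNIST_CNN.pkl，测试脚本随后加载该文件进行评估。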

mnistTrain.py

# -*- coding: utf-8 -*-
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
from tqdm import tqdm


# Hyper-parameters and paths; mnistTest.py imports BATCH_SIZE, DATA_FOLDER,
# DEVICE and MODEL_FILE from this module.
EPOCHS = 25                     # number of full passes over the training set
BATCH_SIZE = 64                 # images per training batch
DATA_FOLDER = 'dataset'         # directory torchvision downloads MNIST into
MODEL_FILE = 'MNIST_CNN.pkl'    # file the best-scoring model is saved to / loaded from
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class CNN(torch.nn.Module):
    """Small convolutional classifier producing 10 class scores per image.

    Architecture: one Conv(5x5, 32 maps) -> BatchNorm -> ReLU -> MaxPool(2)
    stage, then a single fully connected layer. The linear layer's input size
    (32 * 14 * 14) implies 1-channel 28x28 inputs.
    """

    def __init__(self) -> None:
        super().__init__()
        nn = torch.nn
        # padding=2 keeps a 5x5 convolution size-preserving; the 2x2 pool then
        # halves the spatial size (28 -> 14).
        stage = [
            nn.Conv2d(1, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.conv = nn.Sequential(*stage)
        # 32 feature maps of 14x14, flattened per sample, mapped to 10 scores.
        self.fc = nn.Linear(32 * 14 * 14, 10)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Return raw (un-softmaxed) class scores of shape (batch, 10)."""
        return self.fc(torch.flatten(self.conv(feature), start_dim=1))


def evaluate(model: torch.nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the fraction of samples in *loader* that *model* classifies correctly.

    The model is switched to eval mode, so BatchNorm uses its running
    statistics and does not update them. Gradients are disabled. The model is
    deliberately left in eval mode on return; call ``model.train()`` before
    training again.

    Args:
        model: classifier whose output holds one score per class along dim 1.
        loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` if *loader* yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            pred = model(images).argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.numel()
    return correct / total if total else 0.0


if __name__ == '__main__':
    torch.set_num_threads(cpu_count())

    trainData = MNIST(DATA_FOLDER, train=True, transform=ToTensor(), download=True)
    testData = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    trainLoader = DataLoader(trainData, batch_size=BATCH_SIZE, shuffle=True)
    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    testLoader = DataLoader(testData, batch_size=128, shuffle=False)

    cnn = CNN().to(DEVICE)
    lossFunc = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=0.005)

    bestAccuracy = 0
    for epoch in range(EPOCHS):
        # evaluate() leaves the model in eval mode, so re-enable training
        # behaviour (BatchNorm batch statistics) at the start of every epoch.
        cnn.train()
        for images, labels in tqdm(trainLoader, desc=f'Epoch {epoch + 1}/{EPOCHS}'):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            loss = lossFunc(cnn(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        accuracy = evaluate(cnn, testLoader, DEVICE)

        if bestAccuracy < accuracy:
            bestAccuracy = accuracy
            # Saved while still in eval mode, so a loaded copy is ready for inference.
            torch.save(cnn, MODEL_FILE)

        print(f'Accuracy: {accuracy * 100}%    Best Accuracy: {bestAccuracy * 100}%')

mnistTest.py

# -*- coding: utf-8 -*-

from mnistTrain import CNN, BATCH_SIZE, DATA_FOLDER, DEVICE, MODEL_FILE
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from tqdm import tqdm

if __name__ == '__main__':
    import inspect  # only needed to probe torch.load's signature

    testData = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    # Order does not affect accuracy, so the test set is read unshuffled.
    testLoader = DataLoader(testData, batch_size=BATCH_SIZE, shuffle=False)

    # MODEL_FILE holds a whole pickled module written by mnistTrain.py
    # (torch.save(cnn, MODEL_FILE)). Unpickling can execute arbitrary code, so
    # only load model files you produced yourself.
    # map_location lets a model trained on a GPU be loaded on a CPU-only machine.
    loadKwargs = {'map_location': DEVICE}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        # Recent PyTorch versions default to weights_only=True, which refuses
        # whole-module pickles; older versions lack the parameter entirely.
        loadKwargs['weights_only'] = False
    cnn: CNN = torch.load(MODEL_FILE, **loadKwargs).to(DEVICE)
    # Inference mode: BatchNorm must use its stored running statistics rather
    # than per-batch statistics, and must not update them.
    cnn.eval()

    accuracy = 0
    with torch.no_grad():
        for images, labels in tqdm(testLoader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            # Call the module (not .forward) so registered hooks still run.
            pred = cnn(images).argmax(dim=1)
            accuracy += (pred == labels).sum().item()

    accuracy /= len(testData.targets)
    print(f'Accuracy: {accuracy * 100}%')

标签:__,labels,Tensor,卷积,torch,Pytorch,images,import,MNIST
来源: https://www.cnblogs.com/fang-d/p/Pytorch_MNIST_CNN.html