其他分享
首页 > 其他分享> > PyTorch - fashion-MNIST数据集的使用

PyTorch - fashion-MNIST数据集的使用

作者:互联网

FashionMNIST数据集

Fashion-MNIST是一个10类服饰分类数据集, 我们可以使用它来检验不同算法的表现, 这是MNIST数据集不能做到的(原因在这里,想了解的可以看看介绍)。

torchvision的结构

torchvision包包含了很多图像相关的数据集以及处理方法, 并且有常用的模型结构。

# 导入需要的包
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
import matplotlib.pyplot as plt

加载数据

设置数据的缓存目录为 root_dir

随后获得训练集和测试集数据,第一次运行的时候会下载 FashionMNIST 数据集到指定的目录下

下载速度慢解决方案: Gitee 极速下载 Fashion-MNIST

将Fashion-MNIST/ data / fashion的四个压缩文件解压到指定的目录,不要删除原来的压缩包文件,因此数据集总共有八个文件。

# 通过标签得到描述语句
def get_f_mnist_labels(labels):
    """

    :param labels: 图片对应的标签(0-9的数字)
    :return: 标签对应的描述
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def show_fashion_mnist(images, labels):
    """

    :param images: 读取的图片
    :param labels: 图片对应的标签
    :return: None, 输出图片,并且在图片上方对应标签给出描述
    """
    _, figs = plt.subplots(1, len(images), figsize=(12, 2))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)))
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

root_dir = "./torchvision/data/"
f_mnist_train = FashionMNIST(root=root_dir, train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = FashionMNIST(root=root_dir, train=False, download=True, transform=transforms.ToTensor())

print("f_mnist_train length:", len(f_mnist_train), end='\n')
print("f_mnist_test length:", len(f_mnist_test), end='\n')

x, y = [], []
for i in range(10):
    x.append(f_mnist_train[i][0])
    y.append(f_mnist_train[i][1])
show_fashion_mnist(x, get_f_mnist_labels(y))
f_mnist_train length: 60000
f_mnist_test length: 10000

输出10张图片和对应的标签

读取小批量数据

from torch.utils.data import DataLoader

batch_size = 256
train_iter = DataLoader(f_mnist_train, batch_size, shuffle=True, num_workers = 0)

# 计算加载数据的时间
import time
start = time.time()
for X, y in train_iter:
    continue
print("read train data cost %.4f seconds" % (time.time()-start))

read train data cost 4.9213 seconds

注意

本章的介绍思路来自 Apple Store的 “Python AI” app, 作为学习目的使用, 以及在此文章中记录学习过程(如有侵权,请联系作者删除。)

标签:fashion,torchvision,labels,PyTorch,train,MNIST,import,root,mnist
来源: https://www.cnblogs.com/huiyanliu/p/14008883.html