PyTorch - fashion-MNIST数据集的使用
作者:互联网
FashionMNIST数据集
Fashion-MNIST是一个10类服饰分类数据集, 我们可以使用它来检验不同算法的表现, 这是MNIST数据集不能做到的(原因在这里,想了解的可以看看介绍)。
torchvision的结构
torchvision包包含了很多图像相关的数据集以及处理方法, 并且有常用的模型结构。
-
torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:
-
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
-
torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
-
torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
-
torchvision.utils: 其他的一些有用的方法。
# 导入需要的包
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
import matplotlib.pyplot as plt
加载数据
设置数据的缓存目录为 root_dir
随后获得训练集和测试集数据,第一次运行的时候会下载 FashionMNIST 数据集到指定的目录下
下载速度慢解决方案: Gitee 极速下载 Fashion-MNIST
将Fashion-MNIST/ data / fashion的四个压缩文件解压到指定的目录,不要删除原来的压缩包文件,因此数据集总共有八个文件。
# 通过标签得到描述语句
def get_f_mnist_labels(labels):
"""
:param labels: 图片对应的标签(0-9的数字)
:return: 标签对应的描述
"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
def show_fashion_mnist(images, labels):
"""
:param images: 读取的图片
:param labels: 图片对应的标签
:return: None, 输出图片,并且在图片上方对应标签给出描述
"""
_, figs = plt.subplots(1, len(images), figsize=(12, 2))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)))
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
root_dir = "./torchvision/data/"
f_mnist_train = FashionMNIST(root=root_dir, train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = FashionMNIST(root=root_dir, train=False, download=True, transform=transforms.ToTensor())
print("f_mnist_train length:", len(f_mnist_train), end='\n')
print("f_mnist_test length:", len(f_mnist_test), end='\n')
x, y = [], []
for i in range(10):
x.append(f_mnist_train[i][0])
y.append(f_mnist_train[i][1])
show_fashion_mnist(x, get_f_mnist_labels(y))
f_mnist_train length: 60000
f_mnist_test length: 10000
读取小批量数据
from torch.utils.data import DataLoader
batch_size = 256
train_iter = DataLoader(f_mnist_train, batch_size, shuffle=True, num_workers = 0)
# 计算加载数据的时间
import time
start = time.time()
for X, y in train_iter:
continue
print("read train data cost %.4f seconds" % (time.time()-start))
read train data cost 4.9213 seconds
注意
本章的介绍思路来自 Apple Store的 “Python AI” app, 作为学习目的使用, 以及在此文章中记录学习过程(如有侵权,请联系作者删除。)
标签:fashion,torchvision,labels,PyTorch,train,MNIST,import,root,mnist 来源: https://www.cnblogs.com/huiyanliu/p/14008883.html