其他分享
首页 > 其他分享> > 动手学深度学习 图像分类数据集(一) Fashion-MNIST的获取与查看

动手学深度学习 图像分类数据集(一) Fashion-MNIST的获取与查看

作者:互联网

动手学深度学习 图像分类数据集(一) Fashion-MNIST的获取与查看

动手学深度学习 图像分类数据系列:


Fashion-MNIST在书中多次使用,本文的内容是讲解如何获取并查看此数据集


1.下载数据集

使用torchvision.datasets来下载数据集

更多transform的操作可以点击这篇文章来查看

书本原话: 
注意:由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括
transforms.ToTensor() 在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错
但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成
uint8,避免不必要的bug。
import torchvision
import torchvision.transforms as transforms
mnist_train = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=True, download=True,
                                                transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=False, download=True,
                                               transform=transforms.ToTensor())

查看一下读取的结果
在这里插入图片描述

2.查看数据集结构

对训练集切片查看一下数据类型和标签类型
在这里插入图片描述
这里的标签已经转换为数值型数据来存储
所以我们可以编写一个函数将其转换为 图像数据集原本对应的标签

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

在这里插入图片描述

3.查看图片与标签

先提取出其中的一张图片与标签来查看

img, label = mnist_train[0]
title = get_fashion_mnist_labels([label])[0] # 获取标签
plt.imshow(img.view((28,28)).numpy())	# 数据格式转换
plt.title(title)	# 设置标题
plt.savefig('test.jpg')	# 存储图片

在这里插入图片描述
查看多个图片和标签(以前十张为例)

import matplotlib.pyplot as plt
def show_fashion_mnist(images, labels):
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
plt.show()

在这里插入图片描述

4.按小批次读取数据集

使用DataLoader 它可以允许多线程来加速数据读取

具体的可以看下面链接中的文章,有对DataLoaderDataset的详细介绍
Pytorch 快速详解如何构建自己的Dataset完成数据预处理(附详细过程)

from torch.utils.data import DataLoader
import sys
batch_size = 256
if sys.platform.startswith('win'):
    # 0表示不用额外的进程来加速读取数据
    num_workers = 0
else:
    num_workers = 4
train_iter = DataLoader(mnist_train,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers)
test_iter = DataLoader(mnist_test,
                       batch_size=batch_size,
                       shuffle=False,
                       num_workers=num_workers)

DataLoader是个可遍历的对象

start = time()
for X, y in train_iter:
	continue
print('%.2f sec' % (time() - start))

可以通过上述代码来查看读取一遍训练集需要的时间

引用资料来源

本文内容来自吴振宇博士的Github项目
对中文版《动手学深度学习》中的代码进行整理,并用Pytorch实现
【深度学习】李沐《动手学深度学习》的PyTorch实现已完成

标签:Fashion,查看,workers,labels,train,MNIST,图像,数据,mnist
来源: https://blog.csdn.net/Weary_PJ/article/details/113790046