其他分享
首页 > 其他分享> > FashionMNIST数据集简要分析---深度学习&机器学习第五天

FashionMNIST数据集简要分析---深度学习&机器学习第五天

作者:互联网

图像分类数据集—FashionMNIST数据集


①简介:fashionmnist数据集中共有10种类别的服饰,分别为:

['t-shirt', 'toruser', 'pullover', 'dress', 'coat', 'sandal', 'shirt' ,'sneaker', 'bag', 'ankle boots']

部分服饰为:
在这里插入图片描述


②具体介绍:在该数据集中共有7万张图片,每张图片的形状为:[单通道,长28,宽28],并且每张图片对应一种服饰(一种标签)。其中训练集和测试集的图片是分开的,分别有6万张图片和1万张图片。


③探索FashionMNIST数据集

导入相应的库,并下载数据集

%matplotlib inline
import torch
from IPython import display
import torchvision                     # torchvision是关于图像操作的一些方便工具库,对于计算机视觉进行实现的一个库
from torch.utils import data           # 用来读取数据
from torchvision import transforms     # 为pytorch中图像预处理包,包含了很多种对图像进行变化的函数
from d2l import torch as d2l
import matplotlib.pyplot as plt
import time

def use_svg_display():
    # 用矢量图显示图片
    display.set_matplotlib_formats('svg')    # format格式
    
use_svg_display()    # 用svg显示图片,这样图片的清晰度会更高

# 下载数据集
trans = transforms.ToTensor()    # 把shape为(x, y, z)的转换为(z, x, y),并每个元素除以255
                                 # 得到每个元素的数值均在0到1之间
mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)

数据集的探索

len(mnist_train), len(mnist_test)
# answer (60000, 10000) 训练集60000张,测试集10000张

mnist_train[0][0].shape
# torch.Size([1, 28, 28]) 单张图片的通道数和尺寸

数据集的可视化,结果为简介中的图片

def get_fashion_mnist_labels(labels):
    """返回Fashion-MNIST数据集的文本标签。"""
    test_labels = ['t-shirt', 'toruser', 'pullover', 'dress', 'coat', 'sandal', 'shirt' ,'sneaker', 'bag', 'ankle boots']
    return [test_labels[int(i)] for i in labels]

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):    # 该函数还未研究
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
images = show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
images
plt.savefig('部分服饰.png', facecolor='white', edgecolor='red')    # 生成图片的保存 

④导入数据集

把数据集通过函数形式导入到内存中

def load_data_fashion_mnist(batch_size, resize=None):
    """加载Fashion-MNIST数据集到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))# 把图片放大成resize * resize大小
    trans = transforms.Compose(trans)             # 串联多个图片变换的操作
    mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans)
    mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()), 
            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))          

解释两个参数的含义:
batch_size:我们一次读取多少张图片
resize:是否要对图片进行等比例的放大或缩小。eg: resize=66,则图片的尺寸变为66 x 66


⑤加载数据集

train_iter, test_iter = load_data_fashion_mnist(8, 12)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break

结果为:torch.Size([8, 1, 12, 12]) torch.float32 torch.Size([8]) torch.int64
说明:我们一次读取8张图片,每张图片为单通道,尺寸为12 x 12,并且每张图片都有对应的标签,一共8个标签。


⑥查看单张图片

for X, y in test_iter:
    print(X[0].tolist(), y[0])
    break

结果为:
[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003921568859368563, 0.003921568859368563], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.019607843831181526, 0.1411764770746231, 0.019607843831181526, 0.003921568859368563, 0.10196078568696976, 0.062745101749897], [0.0, 0.0, 0.0, 0.0, 0.0, 0.007843137718737125, 0.18431372940540314, 0.4745098054409027, 0.4745098054409027, 0.43921568989753723, 0.47058823704719543, 0.11372549086809158], [0.0, 0.0, 0.0, 0.003921568859368563, 0.003921568859368563, 0.125490203499794, 0.38823530077934265, 0.5333333611488342, 0.6039215922355652, 0.6352941393852234, 0.5803921818733215, 0.1921568661928177], [0.0, 0.003921568859368563, 0.003921568859368563, 0.03529411926865578, 0.14901961386203766, 0.3803921639919281, 0.4588235318660736, 0.5607843399047852, 0.5921568870544434, 0.6117647290229797, 0.5921568870544434, 0.3843137323856354], [0.08235294371843338, 0.1921568661928177, 0.26274511218070984, 0.3607843220233917, 0.4431372582912445, 0.4745098054409027, 0.5254902243614197, 0.5764706134796143, 0.6078431606292725, 0.6078431606292725, 0.6196078658103943, 0.5176470875740051], [0.33725491166114807, 0.47058823704719543, 0.5058823823928833, 0.49803921580314636, 0.5137255191802979, 0.5647059082984924, 0.6078431606292725, 0.6392157077789307, 0.6941176652908325, 0.800000011920929, 0.7686274647712708, 0.5333333611488342], [0.0470588244497776, 0.12156862765550613, 0.24313725531101227, 0.30588236451148987, 0.32156863808631897, 0.3176470696926117, 0.2235294133424759, 0.11764705926179886, 0.20392157137393951, 0.35686275362968445, 0.3176470696926117, 0.20000000298023224], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]
对应的标签为:tensor(9),说明为第9种类型的服饰

结束!!!✨✨✨✨✨✨

完整代码链接:FashionMNIST数据集

标签:学习,0.0,torch,---,train,FashionMNIST,data,mnist,图片
来源: https://blog.csdn.net/cristemw/article/details/119330787