其他分享
首页 > 其他分享> > Pytorch 加载数据集 Fashion-MNIST

Pytorch 加载数据集 Fashion-MNIST

作者:互联网

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

手动下载地址:https://github.com/zalandoresearch/fashion-mnist/blob/master/README.zh-CN.md

大无语事件,虽然说官方下载地址这里明说了这个数据集是集成在pytorch里了,而且查到的torchvision.dataset,也说有这么个函数,不过这个函数里没有定义任何功能函数,这可能就是加载失败的问题所在:

事实胜于雄辩,自动下载数据集失败了:

AttributeError: module 'torchvision.datasets' has no attribute 'FashionMNIST'

好吧,那就是没有了吧。数据集也不大,手动下载:

 加载方式就要相应改变,参考代码(https://blog.csdn.net/CBCZJL/article/details/104414904):

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..")
# 打开读取压缩文件
import gzip
import os
import numpy as np
def data_load(path, kind):
    images_path = os.path.join(path,'%s-images-idx3-ubyte.gz' % kind)
    labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz' % kind)
    with gzip.open(labels_path,'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(),dtype=np.uint8, offset=8)    
    with gzip.open(images_path,'rb') as impath:
        images = np.frombuffer(impath.read(),dtype=np.uint8, offset=16).reshape(len(labels),784)
    return images, labels

# 读取转化数据
X_train, y_train = data_load('C:/Users/CSS/Documents/jupyter-file/Fashion_mnist_dataset','train')
X_test, y_test = data_load('C:/Users/CSS/Documents/jupyter-file/Fashion_mnist_dataset','t10k')
X_train_tensor = torch.Tensor(X_train).reshape(-1,1,28,28)*(1/255)
X_test_tensor = torch.from_numpy(X_test).to(torch.float32).view(-1,1,28,28)*(1/255)
y_train_tensor = torch.from_numpy(y_train).to(torch.float32).view(-1,1)
y_test_tensor = torch.from_numpy(y_test).to(torch.float32).view(-1,1)

mnist_train = torch.utils.data.TensorDataset(X_train_tensor, y_train_tensor)
mnist_test = torch.utils.data.TensorDataset(X_test_tensor, y_test_tensor)

batch_size = 256
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

输出:

<class 'torch.utils.data.dataset.TensorDataset'>
60000 10000

加载过程中唯一有与参考代码不同的地方是这行:X_train_tensor = torch.Tensor(X_train).reshape(-1,1,28,28)*(1/255)

原文是:X_train_tensor = torch.from_numpy(X_train).to(torch.float32).view(-1,1,28,28)*(1/255)

当时运行时报错了,大概就是numpy与tensor的转换问题。

为啥X_test_tensor没报错,我没看明白,如有大佬明白,请不吝赐教。

标签:Fashion,tensor,torch,Pytorch,train,mnist,test,path,MNIST
来源: https://blog.csdn.net/lililinglingling/article/details/119114106