其他分享
首页 > 其他分享> > Pytorch(3)-Torchvision的使用

Pytorch(3)-Torchvision的使用

作者:互联网

import torchvision
# 通过ToTensor()将数据集转为tensor数据类型,并通过compose连接
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])

# 加载数据集,其中CIFAR10是pytorch提供的一种数据集类型,具体参数介绍:
# root:要保存数据集的目录
# train:如果为true创建一个训练数据集,如果为false创建一个测试数据集
# transform:图像转换后的数据类型
# download:如果为true则会在网络中下载该数据集,如果为false则不会下载,如果数据集已经存在则会在控制台输出数据集已存在
train_set = torchvision.datasets.CIFAR10(root="F:\\pytorch\\pytorch01_hello\\dataset\\train\\torchvision_image", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="F:\\pytorch\\pytorch01_hello\\dataset\\train\\torchvision_image", train=False, transform=dataset_transform, download=True)
# print(test_set.classes)
# print(test_set[0])

# 将数据集加载到tensorboard查看,这里查看前十张图片
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("p10课程", img, i)
writer.close()

标签:set,torchvision,transform,dataset,Pytorch,train,使用,test,Torchvision
来源: https://www.cnblogs.com/XiaoMaGuai/p/16298979.html