PyTorch深度学习(3)Transforms CIFAR10
作者:互联网
使用Transforms,需要先引入 from torchvision import transforms
Tensor 张量 实际就是一个多维数组multidimensional array,其目的是能够创造更高维度的矩阵、向量
__call__方法:
魔法函数__call__,即把类当作函数使用,不需要再调用类中的函数
例如:person = Person() person("name") person.Name()
# 创建具体的工具 tool = transforms.ToTensor()
# 使用工具 result = tool(input) 输出结果 调用Tensor __call__是把类当作函数使用
# PIL Image numpy.ndarray --> tensor
# Tensor 张量 实际就是一个多维数组multidimensional array,其目的是能够创造更高维度的矩阵、向量
from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
# python的用法 ——> tensor数据类型
# 通过transforms.ToTensor()
# 绝对路径 D:\PycharmProjects\learn_pytorch\train\ants\5650366_e22b7e1065.jpg 可以在前面添加r转义
# 相对路径 train/ants/5650366_e22b7e1065.jpg
img_path = "train/ants/5650366_e22b7e1065.jpg"
img_PIL = Image.open(img_path)
# transforms 如何使用(python)
# ToTensor()使用
tensor_trans = transforms.ToTensor()
img_trans = tensor_trans(img_PIL) # 调用Tensor的魔法函数__call__ 返回F.to_tensor(pic)
writer = SummaryWriter("logs")
writer.add_image("Tensor_image", img_trans)
writer.close()
ToTensor():
将PIL Image numpy.ndarray 转换为 tensor
创建具体的工具 tool = transforms.ToTensor()
使用工具 result = tool(input)
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
img_path = "train/ants/20935278_9190345f6b.jpg"
img = Image.open(img_path)
writer = SummaryWriter("logs")
# ToTensor
trans_tensor = transforms.ToTensor() # PIL Image or numpy.array
img_tensor = trans_tensor(img)
writer.add_image("ToTensor", img_tensor)
Normalize 归一化
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 参数:均值,标准差
# Normalize 归一化
print(img_tensor[0][0][0]) # 0层0行0列 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])时,输出:0.8275
trans_norm = transforms.Normalize([1, 3, 5], [3, 2, 1]) # 参数:均值、标准差
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0]) # 输出:0.6549
writer.add_image("Normalize", img_norm)
Resize 重置大小
输入为PIL Image
给定一个序列(h, w),最小的边匹配number,等比缩放
# Resize
trans_resize = transforms.Resize((512, 512))
# img PIL -> resize -> img_resize PIL
img_resize = trans_resize(img)
# img_resize PIL -> toTensor -> img_resize tensor
img_tensor_resize = trans_tensor(img_resize)
writer.add_image("Resize", img_tensor_resize, 0)
Compose 将多个参数功能整合
Compose() 中参数需要是一个列表,列表的数据表示形式为[数据1, 数据2, ...]
Compose()中,数据需要时transforms类型,得Compose([transforms参数1, transforms参数2, ...])
# Compose中参数需要一个列表,列表形式为[数据1, 数据2, ...]
# 在Compose中,数据需要的是transforms类型, Compose([transforms参数1, transforms参数2, ...])
trans_resize_2 = transforms.Resize(512) # Resize中一个数,为按照图片最小边进行缩放
trans_compose = transforms.Compose([trans_resize_2, trans_tensor]) # 第一个参数:改变图片大小,第二个参数:转换类型
img_resize_2 = trans_compose(img)
writer.add_image("Resize2", img_resize_2, 1)
RandomCrop 随即裁剪
只会裁剪为指定(h, w) 宽高
# RandomCrop 随机裁剪
trans_rc = transforms.RandomCrop(128)
trans_comp = transforms.Compose([trans_rc, trans_tensor])
for i in range(10):
img_crop = trans_comp(img)
writer.add_image("RandomCropHW", img_crop, i)
writer.close()
Dataset 和 Transforms 联合使用
CIFAR10数据集中共有60000张彩色图像,图像32×32,分为10个类。50000张用于训练,10000张用于测试
10类分别为['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
数据集中数据类型为:(tensor, 类别标号)
# dataset 和 transforms 联合使用
import torchvision
from torch.utils.tensorboard import SummaryWriter
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset_transformer = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# root-数据加载目录 train-true为训练集,false加载测试集 download-是否下载 transform-是否将PIL转换为tensor
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transformer, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transformer, download=True)
print(test_set[0]) # 输出: tensor(..., 3) tensor, 类别
print(test_set.classes) # 输出:总共十种类型,['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img, target = test_set[0]
print(img) # 输出:图片的tensor
print(target) # 输出:类别 3
print(test_set.classes[target]) # 输出: cat
img.show()
writer = SummaryWriter("logs")
for i in range(10):
img, target = test_set[i]
writer.add_image("Test_set", img, i)
writer.close()
标签:tensor,img,CIFAR10,writer,PyTorch,transforms,trans,resize,Transforms 来源: https://blog.csdn.net/jiangyangll/article/details/120783420