其他分享
首页 > 其他分享> > 2021-11-02

2021-11-02

作者:互联网

pytorch中torchvision.transforms的一些理解

1.这个库里面主要是包含了一些图像处理的函数,也就是说使用.transforms的地方同样可以用其他图像库进行处理,例如opencv。
2.这个库一般只用于和torchvision.datasets一起使用的时候,其他的一般自己弄就行了。

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=BATCH_SIZE, shuffle=True)

3.我们使用pytorch的时候用的最多的就是这两句:

   transforms.ToTensor(),#归一化将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor
   transforms.Normalize((0.1307,), (0.3081,))  #标准化是为了加快收敛性 这里的0.1307和0.3081是MNIST数据集里的均值和标准差,因为只有一个通道,所以只写了一个这个东西一般是数据集提供方给出的。

对于其他的操作我们也可以用其他的库进行图像处理。

标签:11,02,0.1307,shape,Normalize,datasets,2021,0.3081,transforms
来源: https://blog.csdn.net/weixin_44587732/article/details/121104845