其他分享
首页 > 其他分享> > pytorch常用transforms

pytorch常用transforms

作者:互联网

import os.path
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset
import re
from functools import reduce
from torch.utils.tensorboard import SummaryWriter as Writer
from torchvision import transforms
#需要安装pip install tb-nightly
#正则表达式匹配出最后的数字:12
#print(re.findall("(\d+)","flower")[-1])
#创建自定义DataSet类
class myDataSet(Dataset):
    #每个分类的子文件夹独立成一个标签数据集,标签例如flower0
    def __init__(self,rootdir,labeldir):
        self.rootdir=rootdir
        self.labeldir=labeldir
        self.imagePaths=os.path.join(rootdir,labeldir)

    #item作为编号:opencv版本
    def __getitem__(self, item):
        imagePath=os.listdir(self.imagePaths)[item]
        imagePath=os.path.join(self.imagePaths,imagePath)
        img=cv2.imdecode(np.fromfile(imagePath,np.uint8),-1)
        #bgr转rgb,避免改变数组的连续性,不然后续transform会报错
        img = img[:, :, ::-1].copy()
        labelComopent =re.findall("(\d+)",self.labeldir)
        #如果在标签中取不出对应tag
        if len(labelComopent)==0:
            raise ValueError
        label=int(labelComopent[-1])
        return img,label

    def __len__(self):
        return len(self.imagePaths)
#使用r标识路径防止转义:
rootdir=r"D:\17flowers"
labelList=os.listdir(rootdir)
allDataSet=[]
#生成各子数据集
for label in labelList:
    allDataSet.append(myDataSet(rootdir,label))
trainDataSet=reduce(lambda x,y:x+y,allDataSet)
#PIL Image or numpy.ndarray转换为tensor,opencv直接读出numpy,PIL读出PIL格式,能够把灰度范围从0-255变换到0-1之间
#同时输入为numpy时,经转换后,通道数会转移至第一位:
t1=transforms.ToTensor()
#中心裁剪
t2=transforms.CenterCrop(300)
#原来的0-1最小值0则变成(0-0.5)/0.5=-1,而最大值1则变成(1-0.5)/0.5=1.
t3=transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
#传入tensor,输出tensor;传入PIL,输出PIL
t4=transforms.Resize((500,500))
t=transforms.Compose([t1,t2,t3,t4])


#载入日志写入器:
writer=Writer("./myBorderText")
for index,datas in enumerate(trainDataSet):
    #存储100张图像:
    if index>10:
        break
    #(500, 689, 3)
    #print(datas[0].shape)
    #注意使用dataformats转变输入图像通道顺序:
    writer.add_image("图片未处理",img_tensor=datas[0],global_step=index,dataformats="HWC")
    #通道数已转移至第一维:
    writer.add_image("图片中心裁剪处理", img_tensor=t(datas[0]), global_step=index)
writer.close()
#查看命令:tensorboard --logdir=./myBorderText

标签:__,常用,self,0.5,rootdir,pytorch,transforms,import
来源: https://blog.csdn.net/hh1357102/article/details/123589485