Pytorch常用工具箱
作者:互联网
神经网络工具箱nn
import torch.nn as nn
在nn中主要有两个重要模块:nn.Model、nn.functional,接着将分别介绍这两个模块。
nn.Model
nn.Model是nn的一个核心数据结构,最常用的做法就是继承nn.Model,如 class Nets(nn.Model),常用的全连接层、损失层、激活层、卷积层等都是nn.Model的子类,nn.Linear、nn.Conv2d等
nn.functional
import torch.nn.functional as F
性能方面与nn.Model有一些差异,但此处不做描述。调用常用的层时用nn.functional.xxx,如nn.functional.linear、nn.functional.conv2d。
utils.data
utils.data 主要包括4个类
(1)Dataset:是一个抽象类,其他数据需要继承这个类,并要覆写其中的两个方法(getitem、len)
(2)DataLoader:定义了一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)等
from pytorch.utils.data import DataLoader
(3)random_split:把数据集随机拆分为给定长度的非重叠的新数据集。
(4)*sampler:多采样函数。
Torchvision
torchvision主要包括4个类
(1)datasets:常用数据集的加载,如MMIST,CIFAR10,设计上继承于torch.utils.data.Dataset
(2)models:提供经典的网络结构以及训练好的模型;
(3)transforms:常用的数据预处理操作,主要包括Tensor及PIL Image对象的操作;
(4)utils:包括两个函数,一个是make_grid,主要是将多张图片拼接在一起,一个是save_img,能将tensor保存成图片。
标签:常用,nn,utils,functional,Pytorch,Model,data,工具箱 来源: https://blog.csdn.net/weixin_42601276/article/details/114577276