Seven---pytorch学习---维度变换
作者:互联网
## pytorch学习(4)
### 维度变换
- view & reshape
- squeeze & unsqueeze
- transpose & permute
- expand & repeat
- contiguous
#### view & reshape
> view() 与 reshape() 的区别
- view() 只适用于满足连续性条件的tensor,且不会开辟新的内存空间
- reshape() 的返回值既可以是视图,也可以是副本,当曼珠连续性条件时返回 view() ,否则返回副本;且使用 reshape() 时,会开辟新的内存空间
- 故当不确定能否使用 view() 时,可以使用 reshape()
##### view()函数
```python
import torch
a = torch.rand(4,3,32,32) #维度为4
b = a.view(4,3,32*32) #维度为3
c = a.view(4,-1) #维度为2,使用-1可进行省略缩进
print(a.shape)
print(b.shape)
print(c.shape)
-------------------------------------------------------
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 1024])
torch.Size([4, 3072])
```
##### reshape()函数
```python
import torch
a = torch.rand(4,3,32,32)
b = a.reshape(4,3,-1)
c = a.reshape(4,-1)
print(b.shape)
print(c.shape)
-------------------------------------------------------
torch.Size([4, 3, 1024])
torch.Size([4, 3072])
```
#### squeeze & unsqueeze
> squeeze & unsqueeze 的功能时维度的减少 / 增加
##### squeeze()函数
- torch.squeeze(input,dim) 返回一个tensor
- 当dim不设置时,去掉input的所有为1的维度
- 当dim为整数时(0 <= dim < input.dim()),判断dim是否为1,若是则删去,否则不变
- dim的取值范围:[-input.dim(),input.dim()-1 ]
```python
import torch
a = torch.rand(1,1,2,2,5)
print(a.squeeze().shape)
print(a.squeeze(0).shape)
print(a.squeeze(1).shape)
print(a.squeeze(2).shape)
print(a.squeeze(-1).shape)
-------------------------------------------------------
torch.Size([2, 2, 5])
torch.Size([1, 2, 2, 5])
torch.Size([1, 2, 2, 5])
torch.Size([1, 1, 2, 2, 5])
torch.Size([1, 1, 2, 2, 5])
```
##### unsqueeze()函数
- torch.unsqueeze(input,dim): 对 input 数据维度进行扩充,在 dim 维插入一个 dim = 1 的维度
- dim 是必选参数
- dim 的取值范围:[-input.dim()-1,input.dim()]
```python
import torch
a = torch.rand(1,1,2,2,5)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(2).shape)
print(a.unsqueeze(3).shape)
print(a.unsqueeze(-1).shape)
-------------------------------------------------------
torch.Size([1, 1, 1, 2, 2, 5])
torch.Size([1, 1, 1, 2, 2, 5])
torch.Size([1, 1, 2, 1, 2, 5])
torch.Size([1, 1, 2, 2, 5, 1])
```
#### transpose & permute
##### transpose()函数
- transpose() 函数只能有两个相关的交换的位置参数
- 使用transpose() 函数的时候,tensor 自身不会改变,因此需要将结果重新赋值
```python
import torch
a = torch.rand(2,3,32,32)
print(a.transpose(0,3).shape)
print(a.transpose(1,2).shape)
-------------------------------------------------------
torch.Size([32, 3, 32, 2])
torch.Size([2, 32, 3, 32])
```
##### permute()函数
- permute() 函数可以一次操作多维数据,且必须传入所有维度数
- 对一个高维的 tensor 执行 permute,并没有改变数据的相对位置,只是旋转了一下这个立方体(或者说改变了对这个立方体的视觉角度)
```python
import torch
a = torch.rand(2,3,32,32)
print(a.permute(0,3,2,1).shape)
print(a.permute(1,3,2,0).shape)
-------------------------------------------------------
torch.Size([2, 32, 32, 3])
torch.Size([3, 32, 32, 2])
```
#### expand & repeat
##### expand()函数
- torch.expand() 返回 tensor 的一个视图,单个维度扩大为更大的维度,或者在第0维新增加一个维度来扩大为更高维
- 如果哪个维度为 -1,就是该维度不变
- 使用expand时,不会创建新的内存地址
##### repeat()函数
- torch.repeat() 会将 tensor 在指定的维度方向上进行重复,参数表示在不同的维度上重复的次数
- 使用repeat时,会重新赋值,重新创建新的内存地址进行占用
```python
import torch
a = torch.rand(1,1,32,32)
print('a的数据是:',a.shape)
print('a的地址为:',a.data_ptr())
b = a.expand(1,3,32,32)
print('b的数据是:',b.shape)
print('b的地址为:',b.data_ptr())
c = a.repeat(1,3,1,1)
print('c的数据是:',c.shape)
print('c的地址为:',c.data_ptr())
d = a.repeat(1,3,2,2)
print('d的数据是:',d.shape)
print('d的地址为:',d.data_ptr())
-------------------------------------------------------
a的数据是: torch.Size([1, 1, 32, 32])
a的地址为: 2191754885568
b的数据是: torch.Size([1, 3, 32, 32])
b的地址为: 2191754885568 #expand执行后的内存地址仍为原a的地址
c的数据是: torch.Size([1, 3, 32, 32])
c的地址为: 2191823748608 #repeat执行后的内存地址发生了改变
d的数据是: torch.Size([1, 3, 64, 64])
d的地址为: 2191823870400
```
#### contiguous
[torch.contiguous方法讲解(这篇blog很通俗便于理解)]: https://blog.csdn.net/qq_37828380/article/details/107855070
标签:Seven,32,torch,shape,---,维度,pytorch,print,Size 来源: https://www.cnblogs.com/311dih/p/16583853.html