A Detailed Explanation of torch.nn.Unfold()

Author: Internet

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

Function: extracts sliding local blocks from a batched input tensor.

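Parameters:

Input: inputs (B, C, H, W)

B: batch size  C: channels  H: height  W: width

Note: Currently, only 4-D input tensors (batched image-like tensors) are supported.

Output: outputs (B, N, L)

N: the number of values in each block, N = C × ∏(kernel_size) = C × kH × kW, where kernel_size = (kH, kW)

L: the total number of blocks,

L = ∏_d ⌊(spatial_size[d] + 2 × padding[d] − dilation[d] × (kernel_size[d] − 1) − 1) / stride[d] + 1⌋

where spatial_size denotes the spatial dimensions of the input tensor, here spatial_size = (H, W), and d ranges over these dimensions, here {0, 1}.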
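The output shape can be checked by hand before running anything. The following is a minimal sketch in plain Python that evaluates N and L using the block-count formula from the PyTorch documentation for nn.Unfold; the helper name unfold_output_shape is hypothetical and is not part of PyTorch, and it assumes that kernel_size, dilation, padding and stride are all given as 2-tuples.

```python
import math

def unfold_output_shape(input_shape, kernel_size,
                        dilation=(1, 1), padding=(0, 0), stride=(1, 1)):
    """Hypothetical helper: predict the (B, N, L) shape produced by nn.Unfold."""
    b, c, h, w = input_shape
    # N: values per block = channels * kernel height * kernel width
    n = c * math.prod(kernel_size)
    # L: product over the spatial dimensions of the number of block positions
    num_blocks = 1
    for d, size in enumerate((h, w)):
        effective = size + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1
        num_blocks *= effective // stride[d] + 1
    return (b, n, num_blocks)

# A 5x6 single-channel input with a 3x3 kernel and stride (2, 1)
shape_a = unfold_output_shape((1, 1, 5, 6), (3, 3), stride=(2, 1))
print(shape_a)  # (1, 9, 8)

# A 10x12 three-channel input with a 4x5 kernel
shape_b = unfold_output_shape((1, 3, 10, 12), (4, 5))
print(shape_b)  # (1, 60, 56)
```

In the first case there are 2 block positions along the height and 4 along the width, which gives L = 8.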

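import torch
import torch.nn as nn

# Input of shape (B, C, H, W) = (1, 1, 5, 6)
inp = torch.tensor([[[[1.0, 2, 3, 4, 5, 6],
                      [7, 8, 9, 10, 11, 12],
                      [13, 14, 15, 16, 17, 18],
                      [19, 20, 21, 22, 23, 24],
                      [25, 26, 27, 28, 29, 30],
                      ]]])
print('inp=')
print(inp)

# 3x3 blocks, stride 2 along the height and stride 1 along the width
unfold = nn.Unfold(kernel_size=(3, 3), dilation=1, padding=0, stride=(2, 1))
# Output shape (B, N, L) = (1, 9, 8): each column is one flattened block
inp_unf = unfold(inp)
print('inp_unf=')
print(inp_unf)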
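To make the layout of inp_unf concrete, the sketch below rebuilds the same 5×6 input and checks that every column of the unfolded tensor equals one 3×3 block flattened row by row, with blocks ordered row-major over their positions. The loop bounds out_h = 2 and out_w = 4 are worked out by hand for this particular input and are an assumption of the sketch, not something returned by nn.Unfold.

```python
import torch
import torch.nn as nn

# Same values as the example input: 1..30 arranged as (B, C, H, W) = (1, 1, 5, 6)
inp = torch.arange(1.0, 31.0).reshape(1, 1, 5, 6)
unfold = nn.Unfold(kernel_size=(3, 3), dilation=1, padding=0, stride=(2, 1))
inp_unf = unfold(inp)  # shape (1, 9, 8)

out_h, out_w = 2, 4  # block positions along height and width (computed by hand)
for i in range(out_h):
    for j in range(out_w):
        # Block whose top-left corner is at row 2*i (stride 2), column j (stride 1)
        patch = inp[0, 0, 2 * i:2 * i + 3, j:j + 3]
        # Column index of this block in the unfolded output
        column = inp_unf[0, :, i * out_w + j]
        assert torch.equal(patch.reshape(-1), column)

first_column = inp_unf[0, :, 0].tolist()
print(first_column)  # [1.0, 2.0, 3.0, 7.0, 8.0, 9.0, 13.0, 14.0, 15.0]
```

The first column is the top-left 3×3 block of the input, read row by row.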

The official documentation explains:

Convolution = Unfold + Matrix Multiplication + Fold

>>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
>>> inp = torch.randn(1, 3, 10, 12)
>>> w = torch.randn(2, 3, 4, 5)
>>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
>>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
>>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
>>> # or equivalently (and avoiding a copy),
>>> # out = out_unf.view(1, 2, 7, 8)
>>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
tensor(1.9073e-06)
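The same idea carries over to strided and padded convolutions. Below is a minimal sketch, assuming integer stride and padding, no dilation, no bias and no groups; the function name conv2d_via_unfold is hypothetical. It uses a reshape to the output shape rather than fold, and compares the result against torch.nn.functional.conv2d.

```python
import torch
import torch.nn.functional as F

def conv2d_via_unfold(x, weight, stride=1, padding=0):
    """Hypothetical helper: 2-D convolution as unfold + matrix multiplication."""
    b, c, h, w = x.shape
    out_c, _, kh, kw = weight.shape
    # Extract blocks: (B, C*kh*kw, L)
    cols = F.unfold(x, (kh, kw), padding=padding, stride=stride)
    # Flatten each filter to a row and multiply: (out_c, N) @ (B, N, L) -> (B, out_c, L)
    out = weight.reshape(out_c, -1) @ cols
    # Spatial size of the output (dilation assumed to be 1)
    out_h = (h + 2 * padding - kh) // stride + 1
    out_w = (w + 2 * padding - kw) // stride + 1
    return out.reshape(b, out_c, out_h, out_w)

torch.manual_seed(0)
x = torch.randn(2, 3, 10, 12)
weight = torch.randn(4, 3, 3, 3)

ref = F.conv2d(x, weight, stride=2, padding=1)
out = conv2d_via_unfold(x, weight, stride=2, padding=1)
max_err = (ref - out).abs().max().item()
print(out.shape, max_err)
```

The maximum difference should be on the order of float32 rounding error, as in the documentation example above.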

Source: https://blog.csdn.net/A3630623/article/details/120639367