torch中的mask:masked_fill, masked_select, masked_scatter
作者:互联网
1. 简介
pytorch提供mask机制用来提取数据中“感兴趣”的部分。过程如下:左边的矩阵是原数据,中间的mask是遮罩矩阵,标记为1的表明对这个位置的数据“感兴趣”-保留,反之舍弃。整个过程可以视作是在原数据上盖了一层mask,只有感兴趣的部分(值为1)显露出来,而其他部分则背遮住。(matlab中也有mask操作)
mask为一个和元数据size相匹配的tensor-bool,相匹配: broadcastable-广播机制。如一个2*3*3的原数据可以由一个3*3的mask来提取。
mask一般是先建立0/1矩阵,然后通过tensor.bool()来转为bool类型的tensor,其他true表示原数据被遮住或者被选中,false表示原数据没有被遮住或者未被选中:这句话在下面的演示中更容易理解。
2. 程序演示
这里涉及的是torch中的三个常见mask函数:masked_fill, masked_select, masked_scatter。
先构造好input和mask矩阵:
imgs = torch.randint(0, 255, [2, 3, 3], dtype=torch.float32) """ tensor([[[182., 242., 11.], [163., 92., 183.], [222., 54., 86.]], [[157., 139., 254.], [158., 148., 46.], [ 1., 13., 56.]]]) """ mask = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).bool() """ tensor([[ True, False, False], [False, True, False], [False, False, True]]) """
1)torch.masked_fill(input, mask, value)
参数:
- input:输入的原数据
- mask:遮罩矩阵
- value:被“遮住的”部分填充的数据,可以取0、1等值,数据类型不限,int、float均可
返回值:一个和input相同size的masked-tensor
使用:
- output = torch.masked_fill(input, mask, value)
- output = input.masked_fill(mask, value)
imgs_masked = torch.masked_fill(input=imgs, mask=~mask, value=0) # 这里mask取反:true表示被“遮住的” """ tensor([[[182., 0., 0.], [ 0., 92., 0.], [ 0., 0., 86.]], [[157., 0., 0.], [ 0., 148., 0.], [ 0., 0., 56.]]]) """
2)torch.masked_select(input, mask, out)
参数:
- input:输入的原数据
- mask:遮罩矩阵
- out:输出的结果,和原tensor不共用内存,一般在左侧接收,而不在形参中赋值
返回值:一维tensor,数据为“选中”的数据
使用:
- torch.masked_select(input, mask, out)
- output = input.masked_select(mask)
selected_ele = torch.masked_select(input=imgs, mask=mask) # true表示selected,false则未选中,所以这里没有取反
# tensor([182., 92., 86., 157., 148., 56.])
3)torch.masked_scatter(input, mask, source)
说明:将从input中mask得到的数据赋值到source-tensor中
参数:
- input:输入的原数据
- mask:遮罩矩阵
- source:遮罩矩阵的”样子“(全零还是全一或是其他),true表示遮住了
返回值:一个和source相同size的masked-tensor
使用:
- output = torch.masked_scatter(input, mask, source)
- output = input.masked_scatter(mask, source)
source = torch.zeros_like(imgs) imgs_masked_copied = torch.masked_scatter(input=imgs, mask=~mask, source=source) """ tensor([[[173., 0., 0.], [ 0., 77., 0.], [ 0., 0., 159.]], [[ 85., 0., 0.], [ 0., 184., 0.], [ 0., 0., 223.]]]) """
3. 参考链接
PyTorch documentation — PyTorch 1.11.0 documentation
标签:tensor,torch,mask,source,masked,input 来源: https://www.cnblogs.com/YuanShiRenY/p/torch_mask.html