
[Explained] einops: handling tensor dimensions elegantly

Author: Internet


einops: Deep learning operations reinvented (for pytorch, tensorflow, jax and others)


Einops works with …


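einops has a minimalistic yet powerful API. It is built mainly around three functions: rearrange, reduce and repeat. (The einops tutorial shows that together they cover stacking, reshaping, transposition, squeeze/unsqueeze, repeat, tile, concatenate, view and numerous reductions.)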

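import numpy as np
from einops import rearrange, reduce, repeat

# rearrange elements according to the pattern
seq = np.random.rand(10, 4, 8)            # example input laid out as (time, batch, channel)
output_tensor = rearrange(seq, 't b c -> b c t')                 # shape (4, 8, 10)
# combine rearrangement and reduction (2x2 mean-pooling, channels moved last)
images = np.random.rand(4, 3, 32, 32)     # example input laid out as (batch, channel, height, width)
output_tensor = reduce(images, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2)  # shape (4, 16, 16, 3)
# copy along a new axis
gray = np.random.rand(32, 32)             # example single-channel image (height, width)
output_tensor = repeat(gray, 'h w -> h w c', c=3)                # shape (32, 32, 3)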
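For comparison, here is roughly what each of the three calls above computes, written in plain NumPy. The input shapes are arbitrary assumptions chosen for illustration. The sketch only reproduces the results and is not how einops is implemented internally.

```python
import numpy as np

# 't b c -> b c t': a pure axis permutation
seq = np.random.rand(10, 4, 8)                      # assumed layout (time, batch, channel)
out1 = seq.transpose(1, 2, 0)                       # (4, 8, 10)

# 'b c (h h2) (w w2) -> b h w c' with 'mean', h2=2, w2=2:
# split h into (h, h2) and w into (w, w2), average over h2 and w2, then move c last
imgs = np.random.rand(4, 3, 32, 32)                 # assumed layout (batch, channel, height, width)
b, c, H, W = imgs.shape
out2 = imgs.reshape(b, c, H // 2, 2, W // 2, 2).mean(axis=(3, 5)).transpose(0, 2, 3, 1)  # (4, 16, 16, 3)

# 'h w -> h w c' with c=3: copy the image along a new last axis
gray = np.random.rand(32, 32)
out3 = np.repeat(gray[:, :, np.newaxis], 3, axis=2)  # (32, 32, 3)
```

The einops versions state the same intent in one pattern string each, without the index bookkeeping of `transpose` and `reshape`.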

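einops focuses on the interface: what goes in and what comes out, rather than how the output is computed (see the code below). All three lines produce the same result. The third line, however, also gives the reader a hint: what we are processing is not a batch of independent images but a sequence (a video). This is what makes einops code easier to read and maintain.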

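import torch
from einops import rearrange
x = torch.rand(10, 3, 32, 32)                    # example input: 10 items, 3 channels, 32x32
y = x.view(x.shape[0], -1)                       # plain PyTorch: the intent is implicit
y = rearrange(x, 'b c h w -> b (c h w)')         # same result, with every axis named
y = rearrange(x, 'time c h w -> time (c h w)')   # same result; 'time' tells the reader this is a video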

einops basics

import numpy as np
from einops import rearrange, reduce, repeat

ims = np.load('test_images.npy', allow_pickle=False)
# There are 6 images of shape 96x96 with 3 color channels packed into tensor
print(ims.shape, ims.dtype)

(6, 96, 96, 3) float64

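test_images.npy can be thought of as one batch of images: batch size 6, image size 96x96, 3 channels, stored in b h w c (channels-last) order. All the einops operations that follow are applied to it.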

einops consists mainly of three functions, rearrange, reduce and repeat, which are explained one at a time below.


1. rearrange: adjusting dimensions

rearrange(ims[0], 'h w c -> w h c')       # swap axes: transpose height and width

rearrange(ims, 'b h w c -> (b h) w c')    # merge axes: stack the images vertically, shape (576, 96, 3)

# or compose a new dimension of batch and width
rearrange(ims, 'b h w c -> h (b w) c')
rearrange(ims, 'b h w c -> h (b w) c').shape

[6, 96, 96, 3] -> [96, (6 * 96), 3]
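In a composed axis such as (b w), the leftmost name varies slowest, just like a row-major reshape. A rough NumPy equivalent of 'b h w c -> h (b w) c' is therefore a transpose that puts b right before w, followed by a reshape. Writing (w b) instead would interleave columns from different images rather than placing the images side by side. The random array below is an assumed stand-in for `ims` with the same shape.

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)   # stand-in for test_images.npy: (b, h, w, c)

# 'b h w c -> (b h) w c': b and h are already adjacent, so a reshape suffices
stacked = ims.reshape(6 * 96, 96, 3)                             # (576, 96, 3)

# 'b h w c -> h (b w) c': move b next to w, then merge them (b outer, w inner)
side_by_side = ims.transpose(1, 0, 2, 3).reshape(96, 6 * 96, 3)  # (96, 576, 3)

# 'b h w c -> h (w b) c': the same axes merged in the other order (w outer, b inner)
interleaved = ims.transpose(1, 2, 0, 3).reshape(96, 96 * 6, 3)
```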

Decomposition of axis:

# decomposition is the inverse process - represent an axis as a combination of new axes
# several decompositions possible, so b1=2 is to decompose 6 to b1=2 and b2=3
rearrange(ims, '(b1 b2) h w c -> b1 b2 h w c ', b1=2).shape

(2, 3, 96, 96, 3)

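Here the parenthesized group (b1 b2) is resolved from b1=2: b2 is inferred automatically (6 / 2 = 3), and the original b axis is split into these two new axes.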
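In NumPy terms, this decomposition is simply a reshape of the leading axis. With b1=2, image number b ends up at position (b // 3, b % 3). A minimal check, again using an assumed random stand-in for `ims`:

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)      # stand-in for test_images.npy

b1 = 2
b2 = ims.shape[0] // b1                 # inferred length of the other factor: 3
split = ims.reshape(b1, b2, 96, 96, 3)  # '(b1 b2) h w c -> b1 b2 h w c'

# image number 4 lands at b1 = 4 // 3 = 1, b2 = 4 % 3 = 1
assert (split[1, 1] == ims[4]).all()
```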


2. reduce:

# for an array x laid out as b h w c, these two lines compute the same thing
x.mean(-1)
reduce(x, 'b h w c -> b h w', 'mean')

# average over batch
reduce(ims, 'b h w c -> h w c', 'mean')
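With 'mean', reduce averages over exactly the axes that appear on the left of the pattern but not on the right. A small NumPy sketch of both patterns, using an assumed random stand-in for `ims`:

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)   # stand-in for test_images.npy: (b, h, w, c)

# 'b h w c -> b h w': c disappears, so average over the last axis
per_pixel = ims.mean(axis=-1)        # (6, 96, 96)

# 'b h w c -> h w c': b disappears, so average over the batch axis
mean_image = ims.mean(axis=0)        # (96, 96, 3)
```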



# this is mean-pooling with 2x2 kernel
# image is split into 2x2 patches, each patch is averaged
# the output shrinks: each 96x96 image becomes 48x48
reduce(ims, 'b (h h2) (w w2) c -> h (b w) c', 'mean', h2=2, w2=2)
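The pooling pattern can be unpacked into three NumPy steps. First split each spatial axis into (outer, 2). Then average over the two inner axes. Finally lay the batch out horizontally. Each 96x96 image becomes 48x48, so with six images side by side the result has shape (48, 288, 3). The input array is an assumed random stand-in for `ims`.

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)   # stand-in for test_images.npy

# split h into (h, h2) and w into (w, w2) with h2 = w2 = 2
patches = ims.reshape(6, 48, 2, 48, 2, 3)          # (b, h, h2, w, w2, c)
pooled = patches.mean(axis=(2, 4))                 # average each 2x2 patch -> (6, 48, 48, 3)

# then lay the batch out horizontally: 'b h w c -> h (b w) c'
result = pooled.transpose(1, 0, 2, 3).reshape(48, 6 * 48, 3)   # (48, 288, 3)
```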

# yet another example. Can you compute result shape?
reduce(ims, '(b1 b2) h w c -> (b2 h) (b1 w)', 'mean', b1=2)
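To check your answer: b1=2 forces b2=3, and c does not appear on the right, so the channels are averaged away. The result is a 3x2 grid of grayscale images with shape (288, 192). The NumPy steps below reproduce that shape with an assumed random stand-in for `ims`.

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)    # stand-in for test_images.npy

b1, b2 = 2, 3                         # b2 = 6 // b1
x = ims.reshape(b1, b2, 96, 96, 3)    # '(b1 b2) h w c -> b1 b2 h w c'
x = x.mean(axis=-1)                   # c is absent on the right, so it is averaged away
x = x.transpose(1, 2, 0, 3)           # reorder to (b2, h, b1, w)
result = x.reshape(b2 * 96, b1 * 96)  # '(b2 h) (b1 w)'
print(result.shape)                   # (288, 192)
```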

3. Addition or removal of axes:

x = rearrange(ims, 'b h w c -> b 1 h w 1 c') # functionality of numpy.expand_dims
print(x.shape)
print(rearrange(x, 'b 1 h w 1 c -> b h w c').shape) # functionality of numpy.squeeze
(6, 1, 96, 96, 1, 3) 
(6, 96, 96, 3)
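The NumPy functions named in the comments above do the same thing when given the positions of the unit axes explicitly. This is a sketch with an assumed random stand-in for `ims`; passing a tuple as `axis` to `np.expand_dims` requires a reasonably recent NumPy (1.18 or later).

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)        # stand-in for test_images.npy

# 'b h w c -> b 1 h w 1 c': insert unit axes at positions 1 and 4
x = np.expand_dims(ims, axis=(1, 4))      # (6, 1, 96, 96, 1, 3)

# 'b 1 h w 1 c -> b h w c': drop those unit axes again
y = np.squeeze(x, axis=(1, 4))            # (6, 96, 96, 3)
```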


# compute max in each image individually, then show a difference 
x = reduce(ims, 'b h w c -> b () () c', 'max') - ims
rearrange(x, 'b h w c -> h (b w) c')
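The pattern 'b h w c -> b () () c' keeps h and w as length-1 axes, and that is what lets the result broadcast against `ims` in the subtraction. In NumPy, `keepdims=True` has the same effect. The input below is an assumed random stand-in for `ims`.

```python
import numpy as np

ims = np.random.rand(6, 96, 96, 3)        # stand-in for test_images.npy

# per-image, per-channel maximum, kept as shape (6, 1, 1, 3) so it broadcasts
max_per_image = ims.max(axis=(1, 2), keepdims=True)
x = max_per_image - ims                   # (6, 96, 96, 3), all values >= 0

# 'b h w c -> h (b w) c': show the six difference images side by side
row = x.transpose(1, 0, 2, 3).reshape(96, 6 * 96, 3)
```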

4. Reductions

From here on, x is a PyTorch tensor of shape (10, 32, 100, 200), laid out as b c h w.

Simple global average pooling:

y = reduce(x, 'b c h w -> b c', reduction='mean')
y.shape
torch.Size([10, 32])
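In NumPy terms, global average pooling is a mean over the two spatial axes. The sketch uses a small random array of shape (10, 32, 8, 16) as an assumed stand-in for x, so that it runs without a GPU-sized tensor.

```python
import numpy as np

x = np.random.rand(10, 32, 8, 16)   # small stand-in for the (10, 32, 100, 200) tensor, layout b c h w

# 'b c h w -> b c' with 'mean': average over the spatial axes h and w
y = x.mean(axis=(2, 3))             # (10, 32)
```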

max-pooling with a 2x2 kernel:

y = reduce(x, 'b c (h h1) (w w1) -> b c h w', reduction='max', h1=2, w1=2)
y.shape
torch.Size([10, 32, 50, 100])
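Max-pooling follows the same split-then-reduce recipe as mean-pooling, with `max` in place of `mean`. Both spatial sizes must be divisible by 2 for the split to work. The input is an assumed small random stand-in for x.

```python
import numpy as np

x = np.random.rand(10, 32, 8, 16)   # small stand-in, layout b c h w

b, c, H, W = x.shape
# 'b c (h h1) (w w1) -> b c h w' with h1 = w1 = 2 and 'max'
y = x.reshape(b, c, H // 2, 2, W // 2, 2).max(axis=(3, 5))   # (10, 32, 4, 8)
```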

Squeeze and unsqueeze (expand_dims):

# models typically work only with batches, 
# so to predict a single image ...
image = rearrange(x[0, :3], 'c h w -> h w c')
# ... create a dummy 1-element axis ...
y = rearrange(image, 'h w c -> () c h w')
# ... imagine you predicted this with a convolutional network for classification,
# we'll just flatten axes ...
predictions = rearrange(y, 'b c h w -> b (c h w)')
# ... finally, decompose (remove) dummy axis
predictions = rearrange(predictions, '() classes -> classes')

per-channel mean-normalization for each image:

y = x - reduce(x, 'b c h w -> b c 1 1', 'mean')
y.shape
torch.Size([10, 32, 100, 200])

per-channel mean-normalization for the whole batch:

y = x - reduce(x, 'b c h w -> 1 c 1 1', 'mean')
y.shape
torch.Size([10, 32, 100, 200])
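The two normalizations differ only in which axes are averaged. Per image, the mean is taken over h and w. For the whole batch, it is taken over b, h and w. Keeping the reduced axes as 1 is what makes the subtraction broadcast. A NumPy sketch, with an assumed small random stand-in for x:

```python
import numpy as np

x = np.random.rand(10, 32, 8, 16)   # small stand-in, layout b c h w

# 'b c h w -> b c 1 1': one mean per image and channel
y_per_image = x - x.mean(axis=(2, 3), keepdims=True)
# 'b c h w -> 1 c 1 1': one mean per channel over the whole batch
y_per_batch = x - x.mean(axis=(0, 2, 3), keepdims=True)
```

After the first version every (image, channel) slice has zero mean. After the second, only each channel taken over the whole batch does.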

5. Concatenation:

concatenate over the first dimension:

list_of_tensors = list(x)  # 10 tensors of shape (32, 100, 200)
tensors = rearrange(list_of_tensors, 'b c h w -> (b h) w c')
tensors.shape
torch.Size([1000, 200, 32])
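When it is given a list, einops first stacks the elements along a new leading axis (named b in the pattern). Merging b with h therefore amounts to concatenating the per-item (h, w, c) arrays along their first axis. The NumPy sketch below computes it both ways, using an assumed small random stand-in for x.

```python
import numpy as np

x = np.random.rand(10, 32, 8, 16)           # small stand-in, layout b c h w
list_of_tensors = list(x)                   # 10 arrays of shape (32, 8, 16)

# 'b c h w -> (b h) w c' on a list: stack to b, move c last, merge b with h
stacked = np.stack(list_of_tensors).transpose(0, 2, 3, 1).reshape(10 * 8, 16, 32)

# the same result written as an explicit concatenation along the first axis
concatenated = np.concatenate([t.transpose(1, 2, 0) for t in list_of_tensors], axis=0)
```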

6. Shuffling within a dimension:

channel shuffle:

y = rearrange(x, 'b (g1 g2 c) h w -> b (g2 g1 c) h w', g1=4, g2=4)
y.shape
torch.Size([10, 32, 100, 200])
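With 32 channels and g1 = g2 = 4, einops infers c = 32 / 16 = 2. The shuffle only reorders channels, so a NumPy sketch that fills each channel with its own index makes the new order visible. The batch and spatial sizes below are arbitrary assumptions.

```python
import numpy as np

b, h, w = 2, 3, 5
g1, g2 = 4, 4
c = 32 // (g1 * g2)                          # inferred: 2 channels per (g1, g2) pair
# channel k is filled with the value k so the new order is easy to read
x = np.zeros((b, 32, h, w), dtype=int) + np.arange(32).reshape(1, 32, 1, 1)

# 'b (g1 g2 c) h w -> b (g2 g1 c) h w': split channels, swap g1 and g2, merge back
y = x.reshape(b, g1, g2, c, h, w).transpose(0, 2, 1, 3, 4, 5).reshape(b, 32, h, w)

print(y[0, :8, 0, 0].tolist())               # [0, 1, 8, 9, 16, 17, 24, 25]
```

The first eight output channels take pairs from each of the four g1 groups in turn, which is exactly the interleaving a channel shuffle is meant to produce.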

Source: https://blog.csdn.net/ViatorSun/article/details/116010049