Equivalent upsampling code in PyTorch and TensorFlow
Author: Internet
Upsampling in PyTorch
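import torch
from torch import nn

# Same numbers as the TensorFlow example: (batch, steps, features) = (2, 2, 3).
# nn.Upsample treats a 3D tensor as (batch, channels, length), so swap the last two axes -> (2, 3, 2)
x = torch.arange(0, 12, dtype=torch.float32).view(2, 2, 3).transpose(1, 2)
# Pass either size or scale_factor, not both
sample_layer = nn.Upsample(scale_factor=2, mode='nearest')
print(x)
# Upsample the length axis -> (2, 3, 4), then transpose back to (batch, steps, features) -> (2, 4, 3)
y = sample_layer(x).transpose(1, 2)
print(y, y.shape)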
Output
tensor([[[ 0.,  3.],
         [ 1.,  4.],
         [ 2.,  5.]],

        [[ 6.,  9.],
         [ 7., 10.],
         [ 8., 11.]]])
tensor([[[ 0.,  1.,  2.],
         [ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.],
         [ 9., 10., 11.]]]) torch.Size([2, 4, 3])
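With an integer scale factor, nearest-neighbour mode in nn.Upsample simply repeats each position along the length axis. The quick sanity check below rebuilds the same channels-first input as above and compares the layer's output with torch.repeat_interleave. It is a sketch for illustration only, not how PyTorch implements the layer internally.

```python
import torch
from torch import nn

# Channels-first input (batch, channels, length), built as above -> shape (2, 3, 2)
x = torch.arange(0, 12, dtype=torch.float32).view(2, 2, 3).transpose(1, 2)

# Nearest-neighbour upsampling by an integer factor of 2 along the length axis
up = nn.Upsample(scale_factor=2, mode='nearest')(x)
# Repeat every position twice along the same (last) axis
rep = torch.repeat_interleave(x, repeats=2, dim=-1)

print(up.shape)              # torch.Size([2, 3, 4])
print(torch.equal(up, rep))  # True
```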
The TensorFlow implementation
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import UpSampling1D
#
# Arguments:
# size: Integer. Upsampling factor.
#
# Input shape:
# 3D tensor with shape: `(batch_size, steps, features)`.
#
# Output shape:
# 3D tensor with shape: `(batch_size, upsampled_steps, features)`.
input_shape = (2, 2, 3)
x = np.arange(np.prod(input_shape)).reshape(input_shape)
print(x)
# [[[ 0  1  2]
#   [ 3  4  5]]
#  [[ 6  7  8]
#   [ 9 10 11]]]
y = tf.keras.layers.UpSampling1D(size=2)(x)
print(y)
# tf.Tensor(
#   [[[ 0  1  2]
#     [ 0  1  2]
#     [ 3  4  5]
#     [ 3  4  5]]
#    [[ 6  7  8]
#     [ 6  7  8]
#     [ 9 10 11]
#     [ 9 10 11]]], shape=(2, 4, 3), dtype=int64)
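UpSampling1D repeats each temporal step size times along the time axis. For this input it therefore behaves like numpy.repeat on axis 1. Below is a minimal NumPy reference that assumes a (batch_size, steps, features) array. Note that repeat_steps is a hypothetical helper name and is not part of Keras.

```python
import numpy as np

def repeat_steps(x: np.ndarray, size: int) -> np.ndarray:
    """Hypothetical NumPy reference for UpSampling1D.

    Assumes x has shape (batch_size, steps, features) and repeats
    every step `size` times along axis 1.
    """
    return np.repeat(x, size, axis=1)

input_shape = (2, 2, 3)
x = np.arange(np.prod(input_shape)).reshape(input_shape)
y_ref = repeat_steps(x, 2)
print(y_ref.shape)  # (2, 4, 3)
print(y_ref[0])
# [[0 1 2]
#  [0 1 2]
#  [3 4 5]
#  [3 4 5]]
```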
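The two produce identical results: both repeat each step along the steps axis. The only difference is the layout. nn.Upsample expects channels-first input (batch, channels, length), whereas UpSampling1D works on (batch_size, steps, features). This is why the PyTorch code transposes axes 1 and 2 before and after upsampling. The equivalence holds for mode='nearest' with an integer scale factor, because UpSampling1D only ever repeats steps.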
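To avoid repeating the transposes, you can wrap the PyTorch side in a small Keras-style helper. The sketch below assumes channels-last input of shape (batch_size, steps, features). The name upsample1d_channels_last is hypothetical. The check does not call TensorFlow; it compares against np.repeat, standing in for what UpSampling1D(size=2) returns for this input.

```python
import numpy as np
import torch
import torch.nn.functional as F

def upsample1d_channels_last(x: torch.Tensor, size: int = 2) -> torch.Tensor:
    """Hypothetical helper: Keras-style UpSampling1D in PyTorch.

    x has shape (batch_size, steps, features); the result has
    shape (batch_size, steps * size, features).
    """
    # F.interpolate expects channels-first (batch, channels, length), so move features to axis 1
    y = F.interpolate(x.transpose(1, 2), scale_factor=size, mode='nearest')
    # Move back to channels-last to match the Keras layout
    return y.transpose(1, 2)

# Same data as the TensorFlow example
input_shape = (2, 2, 3)
x_np = np.arange(np.prod(input_shape)).reshape(input_shape)
x = torch.from_numpy(x_np).float()

y = upsample1d_channels_last(x, size=2)
expected = np.repeat(x_np, 2, axis=1)  # stand-in for the UpSampling1D(size=2) output

print(y.shape)                              # torch.Size([2, 4, 3])
print(np.array_equal(y.numpy(), expected))  # True
```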
Source: https://www.cnblogs.com/boyknight/p/16701393.html