其他分享
首页 > 其他分享> > TensorFlow2子类模型多输入多输出

TensorFlow2子类模型多输入多输出

作者:互联网

在最近的一次项目中,因为需要模型具有多输入多输出,而且我的一个输出是一个包含张量的列表,所以无法使用函数式API或者容器去造模型,因为列表的添加操作不是一个层,而这两类的输出必须是层的结果,虽然可以用tf.keras.layers.Lambda将此操作变成层,但总归是牵强的,所以使用子类模型。

class Test(keras.Model):
    def __init__(self):
        super(Test, self).__init__()
        filters = 64
        initializer = tf.random_normal_initializer(0., 0.02)
        self.conv1 = Conv2D(filters, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters*2, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn2 = BatchNormalization()
        self.conv3 = Conv2D(filters*4, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn3 = BatchNormalization()

    def call(self, inputs):
        x1 = inputs[0]
        x2 = inputs[1]
        skips = []      # 存结果的列表
        x1_1 = tf.nn.relu(self.bn1(self.conv1(x1)))
        x2_1 = tf.nn.relu(self.bn1(self.conv1(x2)))
        skips.append(x1_1)

        x1_2 = tf.nn.relu(self.bn1(self.conv1(x1_1)))
        x2_2 = tf.nn.relu(self.bn1(self.conv1(x2_1)))
        skips.append(x1_2)

        x1_3 = tf.nn.relu(self.bn1(self.conv1(x1_2)))
        x2_3 = tf.nn.relu(self.bn1(self.conv1(x2_2)))
        skips.append(x1_3)

        return [skips, x2_3]


model = Test()
model.build(input_shape=[(batch_size, data_size), (batch_size, data_size)])
input1 = tf.random.normal([batch_size, data_size])
input2 = tf.random.normal([batch_size, data_size])
out_put1, out_put2 = model([input1, input2])

TF2用着真的是太难受了,网上的教程都比较泛,对一些细节的处理实例太难找了,找着了还大概率是tf.compat.v1。。。  做完这次我真的好好去看看Torch了。。。

另外,在TF2的图执行模式里,是无法使用for等循环的,但有专门的库函数tf.while_loop,反正我还不怎么会用,或者直接可以转eager模式就可以解决。

我的问题可能在一些大佬看来很低级,但确实给我造成了麻烦,我本以为教程上的东西就能解决一切问题的了,还是太弱。若要朋友想指正我的说法或者想要交流TF2里的坑,请私信我。

标签:TensorFlow2,子类,self,initializer,x2,tf,x1,输入,size
来源: https://blog.csdn.net/ReichQin/article/details/120392116