# TensorFlow 2 实现 Coordinate Attention (TensorFlow 2 implementation of Coordinate Attention)
# 作者: 互联网 (author: not credited; reposted from the internet)
import tensorflow as tf
from tensorflow.keras.layers import (Conv2D,AvgPool2D,Input)
def CoordAtt(x, reduction=32, name='ca'):
    """Apply a Coordinate Attention block to a 4-D feature map.

    The input is average-pooled along each spatial axis separately, giving one
    descriptor per row (height) and one per column (width). The two descriptors
    are concatenated along axis 1 and passed through a shared 1x1 convolution
    with a hard-swish activation. They are then split apart again, and each is
    turned into a sigmoid attention map by its own 1x1 convolution. The input
    is rescaled by both maps.

    Args:
        x: Input tensor, unpacked as (batch, height, width, channels), i.e.
            channels-last. Height, width and channels must be statically
            known because they size the pooling windows, the split and the
            output convolutions.
        reduction: Channel reduction ratio for the shared bottleneck
            convolution. The bottleneck width is
            ``max(8, channels // reduction)``.
        name: Prefix for the three Conv2D layer names (``{name}_conv1``,
            ``{name}_conv2``, ``{name}_conv3``). The default ``'ca'`` keeps
            the original names ``ca_conv1``..``ca_conv3``. Pass a distinct
            prefix for every additional CoordAtt block in the same Keras
            model, because Keras layer names must be unique.

    Returns:
        ``x * a_h * a_w``, which broadcasts back to the shape of ``x``.

    Raises:
        ValueError: If ``x`` is not rank 4, or its height, width or channel
            dimension is not statically known.
    """
    def coord_act(t):
        # Hard-swish activation: t * relu6(t + 3) / 6.
        return t * (tf.nn.relu6(t + 3) / 6)

    x_shape = x.get_shape().as_list()
    if len(x_shape) != 4:
        raise ValueError(
            'CoordAtt expects a rank-4 (batch, height, width, channels) '
            'input, got shape %s' % (x_shape,))
    _, h, w, c = x_shape
    if h is None or w is None or c is None:
        raise ValueError(
            'CoordAtt requires static height, width and channel dimensions, '
            'got shape %s' % (x_shape,))

    # Pool over the width axis (window (1, w), 'valid' padding): (b, h, 1, c).
    x_h = AvgPool2D(pool_size=(1, w), strides=1)(x)
    # Pool over the height axis (window (h, 1)): (b, 1, w, c).
    x_w = AvgPool2D(pool_size=(h, 1), strides=1)(x)
    # Transpose to (b, w, 1, c) so both descriptors stack along axis 1.
    x_w = tf.transpose(x_w, [0, 2, 1, 3])
    y = tf.concat([x_h, x_w], axis=1)  # (b, h + w, 1, c)

    mip = max(8, c // reduction)
    # NOTE(review): the reference Coordinate Attention implementation
    # (Hou et al., CVPR 2021) applies BatchNorm between this convolution and
    # the activation; it is omitted here. Confirm whether that is intended.
    y = Conv2D(mip, (1, 1), strides=1, activation=coord_act,
               name=f'{name}_conv1')(y)

    # Split back into the h row entries and the w column entries. An equal
    # split (num_or_size_splits=2) is only correct when h == w. For
    # non-square inputs it mixes rows with columns, and it fails outright
    # when h + w is odd.
    x_h, x_w = tf.split(y, num_or_size_splits=[h, w], axis=1)
    x_w = tf.transpose(x_w, [0, 2, 1, 3])  # back to (b, 1, w, mip)

    a_h = Conv2D(c, (1, 1), strides=1, activation=tf.nn.sigmoid,
                 name=f'{name}_conv2')(x_h)  # (b, h, 1, c)
    a_w = Conv2D(c, (1, 1), strides=1, activation=tf.nn.sigmoid,
                 name=f'{name}_conv3')(x_w)  # (b, 1, w, c)

    # The singleton axes of a_h and a_w broadcast against x's (h, w) axes.
    out = x * a_h * a_w
    return out
if __name__ == '__main__':
    # Smoke test: build the attention block on a symbolic 224x224x3 input
    # and print the resulting static output shape.
    image = Input(shape=(224, 224, 3))
    print(CoordAtt(image).shape)
# 标签 (tags): tensorflow2, __, name, attention, strides, shape, tf, coordinate, Conv2D
# 来源 (source): https://blog.csdn.net/sdhdsf132452/article/details/122821992