Coordinate Attention + ResNet + PyTorch 实现
作者:互联网
# CA (coordinate attention)
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torchsummary import summary
import torch.utils.model_zoo as model_zoo
# Public API of this module. Only names that are actually defined in this
# file are listed: advertising names that do not exist makes
# ``from <module> import *`` fail with AttributeError.
__all__ = [
    'ResNet',
    'BottleneckBlock',
    'CoordAttention',
    'ca_resnet34',
    'resnet_CA_instance',
]

# Download URLs of the pretrained ResNet checkpoints hosted on
# download.pytorch.org, keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
import torch
import torch.nn as nn
class h_sigmoid(nn.Module):
    """Hard sigmoid activation, computed as ``ReLU6(x + 3) / 6``.

    Args:
        inplace: Forwarded to ``nn.ReLU6``. The ReLU6 is applied to the
            freshly created ``x + 3`` tensor, so the caller's input is not
            modified even when this is True.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        """Return the element-wise hard sigmoid of ``x``."""
        shifted = x + 3
        return self.relu(shifted) / 6
class h_swish(nn.Module):
    """Hard swish activation: the input scaled by its hard-sigmoid gate.

    Args:
        inplace: Forwarded to the wrapped ``h_sigmoid`` module.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        """Return ``x * h_sigmoid(x)`` element-wise."""
        gate = self.sigmoid(x)
        return x * gate
class CoordAttention(nn.Module):
    """Coordinate attention gate for 4-D feature maps.

    The input is average-pooled along each spatial axis separately, the two
    pooled strips are concatenated and passed through a shared
    1x1 conv + batch norm + h_swish, then split again. Each half goes through
    its own 1x1 conv + sigmoid, and the resulting two gates are multiplied
    onto the input.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the two gates. They are multiplied
            with the input, so this has to broadcast against ``in_channels``.
        reduction: Divisor for the hidden width, which is
            ``max(8, in_channels // reduction)``.
    """

    def __init__(self, in_channels, out_channels, reduction=32):
        super().__init__()
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # averages over height
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # averages over width
        hidden = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.act1 = h_swish()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``x`` re-weighted by its height-wise and width-wise gates."""
        _, _, height, width = x.shape

        pooled_h = self.pool_h(x)                      # (n, c, H, 1)
        pooled_w = self.pool_w(x).permute(0, 1, 3, 2)  # (n, c, W, 1)

        # Stack both strips along dim 2 so one conv/BN/activation serves both.
        mixed = self.conv1(torch.cat([pooled_h, pooled_w], dim=2))
        mixed = self.act1(self.bn1(mixed))

        feat_h, feat_w = torch.split(mixed, [height, width], dim=2)
        gate_h = torch.sigmoid(self.conv2(feat_h))
        # Undo the earlier permute so the width gate is (n, c, 1, W) again.
        gate_w = torch.sigmoid(self.conv3(feat_w.permute(0, 1, 3, 2)))
        return x * gate_w * gate_h
# Build the CA-ResNet backbone (bottleneck residual blocks with coordinate attention)
class BottleneckBlock(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with coordinate attention.

    The attention gate is applied to the output of the last batch norm,
    before the shortcut is added.

    Args:
        inplanes: Channel count of the block input.
        planes: Base width; the block outputs ``planes * expansion`` channels.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when None the input itself is the shortcut.
        groups: Group count of the 3x3 convolution.
        base_width: Scales the inner width as ``planes * base_width / 64``.
        dilation: Dilation (and padding) of the 3x3 convolution.
        norm_layer: Normalisation layer factory; defaults to
            ``nn.BatchNorm2d``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1,
                 norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=dilation,
                               dilation=dilation, groups=groups, bias=False)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, out_planes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
        self.ca = CoordAttention(in_channels=out_planes, out_channels=out_planes)

    def forward(self, x):
        """Run the residual branch, gate it with attention, add the shortcut."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        identity = x if self.downsample is None else self.downsample(x)

        out = self.ca(out)  # coordinate attention on the residual branch
        out += identity
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet-style backbone assembled from a caller-supplied residual block.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept the arguments
            ``(inplanes, planes, stride, downsample, groups, base_width,
            dilation, norm_layer)``.
        depth: One of 18, 34, 50, 101 or 152. It only selects the number of
            blocks per stage; the block type is whatever ``block`` is.
        n_class: Output size of the final linear layer. When ``<= 0`` no
            ``fc`` layer is created and ``forward`` returns features.
        with_pool: Whether to apply global average pooling after the last
            stage. Note that ``fc`` expects ``512 * block.expansion``
            features, so with ``with_pool=False`` and ``n_class > 0`` the
            last feature map must already be 1x1.

    Raises:
        KeyError: If ``depth`` is not one of the supported values.
    """

    def __init__(self, block, depth, n_class=1000, with_pool=True):
        super().__init__()
        # Number of residual blocks in each of the four stages.
        layer_cfg = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }
        layers = layer_cfg[depth]
        self.num_classes = n_class
        self.with_pool = with_pool
        self._norm_layer = nn.BatchNorm2d

        # Running state consumed and updated by _make_layer.
        self.inplanes = 64
        self.dilation = 1

        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        if with_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if n_class > 0:
            self.fc = nn.Linear(512 * block.expansion, n_class)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks.

        Args:
            block: Residual block class.
            planes: Base width of the stage.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block.
            dilate: When True, the stride is traded for dilation: the stage
                keeps its spatial resolution and ``self.dilation`` is
                multiplied by ``stride``.

        Returns:
            An ``nn.Sequential`` holding the stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # A projection shortcut is needed whenever the first block changes
        # the spatial size or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride=stride, bias=False),
                norm_layer(planes * block.expansion),
            )

        # The first block still sees the dilation of the previous stage.
        layers = [
            block(self.inplanes, planes, stride, downsample, 1, 64, previous_dilation, norm_layer)
        ]
        self.inplanes = planes * block.expansion
        # The remaining blocks must use the stage's own (possibly increased)
        # dilation; without passing it they would fall back to dilation 1.
        layers.extend(
            block(self.inplanes, planes, dilation=self.dilation, norm_layer=norm_layer)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem and the four stages, then the optional pool and fc."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.with_pool:
            x = self.avgpool(x)
        if self.num_classes > 0:
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x
def ca_resnet34(**kwargs):
    """Build the coordinate-attention ResNet with the depth-34 stage layout.

    Despite the name, the network is made of ``BottleneckBlock`` units; the
    depth value only selects the ``[3, 4, 6, 3]`` blocks-per-stage table.

    Args:
        **kwargs: Forwarded to ``ResNet`` (e.g. ``n_class``, ``with_pool``).

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(BottleneckBlock, 34, **kwargs)
    return model
def resnet_CA_instance(n_class, pretrained=False, **kwargs):
    """Build the CA-ResNet and optionally warm-start it from a checkpoint.

    Args:
        n_class: Number of output classes of the classification head. When
            ``<= 0`` the network has no ``fc`` layer and the head
            re-initialisation is skipped.
        pretrained: When True, download pretrained weights and copy every
            tensor whose name and shape match the model. The attention
            layers and the classification head keep their fresh
            initialisation.
        **kwargs: Forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(BottleneckBlock, 34, n_class, **kwargs)
    if pretrained:
        # Depth 34 only selects the [3, 4, 6, 3] stage sizes; combined with
        # bottleneck blocks the parameter layout is that of ResNet-50. The
        # resnet34 checkpoint shares many key names but stores BasicBlock
        # tensors of different shapes, which load_state_dict rejects.
        pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
        model_dict = model.state_dict()
        # Keep only tensors the model can actually take: same key AND same
        # shape (this drops e.g. the 1000-way fc when n_class differs).
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        # Overlay the compatible tensors on the model's own state.
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    if n_class > 0:
        # Replace the head with a freshly initialised n_class-way classifier.
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, n_class)
        stdv = 1.0 / math.sqrt(1000)
        for p in model.fc.parameters():
            nn.init.uniform_(p, -stdv, stdv)
    return model
# Inspect the model with high-level APIs. Guarded so that importing this
# module does not build a network and print to stdout as a side effect.
if __name__ == "__main__":
    ca_res34 = ca_resnet34(n_class=15)
    print(ca_res34)
    x = torch.rand(1, 3, 224, 224)
    i = ca_res34(x)
    print(i.shape)
    # The model lives on the CPU; torchsummary defaults to device="cuda" and
    # would feed CUDA tensors to it on a GPU machine.
    summary(ca_res34, (3, 224, 224), device="cpu")
引用请注明作者名:叫我小张就行了
标签:__,layer,nn,self,Attention,stride,pytorch,Coordinate,out 来源: https://blog.csdn.net/qq_37278761/article/details/117249138