Paper notes: CCNet: Criss-Cross Attention for Semantic Segmentation
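
The listing below is a DeepLabv3-style network on a ResNet-50 backbone: layer3's output feeds a deep-supervision head (dsn), layer4's output goes through both DeepLabv3's ASPP and CCNet's recurrent criss-cross attention module (RCCAModule), and the three class-score maps are concatenated and fused by a 1x1 convolution before bilinear upsampling back to the input resolution.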

import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F
from torch.nn import Softmax

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}

BatchNorm2d = nn.BatchNorm2d

# Builds a (B*W, H, H) tensor with -inf on the diagonal of each H x H slice.
# It is added to the H-direction energies so that a position does not attend to
# itself twice (the self position is already scored by the W-direction energies).
def INF(B, H, W, device):
    return -torch.diag(torch.tensor(float("inf"), device=device).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)
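
# A quick shape check for INF (sizes are illustrative): one H x H slice per
# (batch, column) pair, with -inf on the diagonal and 0 elsewhere.
if __name__ == "__main__":
    mask = INF(2, 3, 4, torch.device('cpu'))
    print(mask.shape)  # torch.Size([8, 3, 3])
    print(mask[0])     # -inf on the diagonal, 0 elsewhere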


class CC_module(nn.Module):
    # Criss-cross attention: every position attends only to the positions in its
    # own row and column (H + W - 1 positions instead of H x W).

    def __init__(self, in_dim):
        super(CC_module, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))
          
    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width, x.device)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        #concate = concate * (concate>torch.mean(concate,dim=3,keepdim=True)).float()
        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        return self.gamma*(out_H + out_W) + x
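
# A minimal smoke test for CC_module (sizes are illustrative, not from the
# paper): the attention output is added residually, so the shape is preserved.
if __name__ == "__main__":
    cc = CC_module(in_dim=64)
    feat = torch.randn(2, 64, 20, 20)
    print(cc(feat).shape)  # torch.Size([2, 64, 20, 20]) -- same as the input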


class RCCAModule(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes):
        super(RCCAModule, self).__init__()
        inter_channels = in_channels // 4
        self.conva = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                   BatchNorm2d(inter_channels),nn.ReLU(inplace=False))
        self.cca = CC_module(inter_channels)
        self.convb = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                   BatchNorm2d(inter_channels),nn.ReLU(inplace=False))

        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels+inter_channels, out_channels, kernel_size=3, padding=1, dilation=1, bias=False),
            BatchNorm2d(out_channels),nn.ReLU(inplace=False),
            nn.Dropout2d(0.1),
            nn.Conv2d(out_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
            )

    def forward(self, x, recurrence=2):
        output = self.conva(x)
        # apply criss-cross attention R times; with R=2 every position has
        # gathered context from the whole image (two criss-cross hops)
        for i in range(recurrence):
            output = self.cca(output)
        output = self.convb(output)

        output = self.bottleneck(torch.cat([x, output], 1))
        return output
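
# Illustrative use of RCCAModule on a layer4-sized feature map (the channel
# counts below match how it is constructed in ResNet.__init__ further down):
if __name__ == "__main__":
    rcca = RCCAModule(in_channels=2048, out_channels=512, num_classes=4)
    scores = rcca(torch.randn(2, 2048, 20, 20), recurrence=2)
    print(scores.shape)  # torch.Size([2, 4, 20, 20]) -- per-pixel class scores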


# A Conv2d subclass built on F.conv2d with Weight Standardization: F.conv2d is just
# the convolution op, while nn.Conv2d is the layer that owns the weights. Here each
# filter is standardized (zero mean, unit std) before the convolution is applied.
class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias)
    def forward(self, x):
        # return super(Conv2d, self).forward(x)
        weight = self.weight
        # per-filter mean over (in_channels, kH, kW)
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # per-filter standard deviation, with a small epsilon for numerical stability
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
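
# A quick check of the weight standardization above (illustrative sizes): the
# effective filters should have roughly zero mean and unit standard deviation.
if __name__ == "__main__":
    ws = Conv2d(3, 8, kernel_size=3, padding=1)
    w = ws.weight - ws.weight.mean(dim=(1, 2, 3), keepdim=True)
    w = w / (w.view(8, -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5)
    print(w.view(8, -1).mean(dim=1))            # ~0 for every filter
    print(ws(torch.randn(1, 3, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])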

# The ResNet block type: a 1x1 / 3x3 / 1x1 convolution stack that first reduces
# and then restores the channel count to cut computation.
class Bottleneck(nn.Module):
    expansion = 4  # factor by which the last 1x1 conv restores the channel count
    # `planes` here is not the block's output width but the reduced width inside the block
    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, conv=None, norm=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm(planes)
        self.conv2 = conv(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = norm(planes)
        self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = norm(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        # downsample uses a 1x1 convolution to match channels (and stride) so the residual can be added directly
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
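
# Illustrative: a stage-1 bottleneck maps 256 -> 64 -> 64 -> 256 channels, so the
# residual can be added back without a downsample projection.
if __name__ == "__main__":
    blk = Bottleneck(256, 64, conv=nn.Conv2d, norm=nn.BatchNorm2d)
    print(blk(torch.randn(2, 256, 56, 56)).shape)  # torch.Size([2, 256, 56, 56])
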
# The ASPP module from DeepLabv3
class ASPP(nn.Module):
    def __init__(self, C, depth, num_classes, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1):
        super(ASPP, self).__init__()
        self._C = C          # number of input channels to ASPP
        self._depth = depth  # number of filters per branch
        self._num_classes = num_classes

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.relu = nn.ReLU(inplace=True)
        # first branch: 1x1 convolution
        self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        # atrous (dilated) convolution branches with rates 6, 12, 18
        self.aspp2 = conv(C, depth, kernel_size=3, stride=1,
                               dilation=int(6*mult), padding=int(6*mult),
                               bias=False)
        self.aspp3 = conv(C, depth, kernel_size=3, stride=1,
                               dilation=int(12*mult), padding=int(12*mult),
                               bias=False)
        self.aspp4 = conv(C, depth, kernel_size=3, stride=1,
                               dilation=int(18*mult), padding=int(18*mult),
                               bias=False)
        # image-level branch: global average pooling followed by a 1x1 conv with `depth` filters; all branches are followed by BN
        self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        # pass momentum by keyword: BatchNorm2d's second positional argument is eps, not momentum
        self.aspp1_bn = norm(depth, momentum=momentum)
        self.aspp2_bn = norm(depth, momentum=momentum)
        self.aspp3_bn = norm(depth, momentum=momentum)
        self.aspp4_bn = norm(depth, momentum=momentum)
        self.aspp5_bn = norm(depth, momentum=momentum)
        # the pooled branch is bilinearly upsampled back in forward(); this 1x1 conv then fuses the five concatenated branches
        self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1,
                               bias=False)
        self.bn2 = norm(depth, momentum)
        # per-pixel class scores
        self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        x1 = self.aspp1(x)
        x1 = self.aspp1_bn(x1)
        x1 = self.relu(x1)
        x2 = self.aspp2(x)
        x2 = self.aspp2_bn(x2)
        x2 = self.relu(x2)
        x3 = self.aspp3(x)
        x3 = self.aspp3_bn(x3)
        x3 = self.relu(x3)
        x4 = self.aspp4(x)
        x4 = self.aspp4_bn(x4)
        x4 = self.relu(x4)
        x5 = self.global_pooling(x)
        x5 = self.aspp5(x5)
        x5 = self.aspp5_bn(x5)
        x5 = self.relu(x5)
        # upsample the pooled branch back to the spatial size of x via bilinear interpolation
        x5 = F.interpolate(x5, size=(x.shape[2], x.shape[3]), mode='bilinear',
                           align_corners=True)
        # concatenating the five branches multiplies the channel count by 5
        x = torch.cat((x1, x2, x3, x4, x5), 1)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)

        return x
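
# Illustrative use of ASPP on a layer4-sized feature map (matching how it is
# constructed in ResNet.__init__ below):
if __name__ == "__main__":
    aspp = ASPP(C=2048, depth=256, num_classes=4)
    print(aspp(torch.randn(2, 2048, 20, 20)).shape)  # torch.Size([2, 4, 20, 20])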


# DeepLabv3 on a ResNet backbone, extended with the CCNet (RCCA) head
class ResNet(nn.Module):
    def __init__(self, block, block_num, num_classes, num_groups=None, weight_std=False, beta=False, pretrained=False):
        self.inplanes = 64  # input width of the current residual stage; `planes` is each stage's base width
        # choice of normalization: nn.BatchNorm2d here (nn.GroupNorm is the alternative)
        self.norm = nn.BatchNorm2d
        self.conv = Conv2d if weight_std else nn.Conv2d
        super(ResNet, self).__init__()

        if not beta:
            # standard ResNet stem: a single 7x7 convolution
            self.conv1 = self.conv(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)
        else:
            # 'beta' stem: three 3x3 convolutions replacing the 7x7
            self.conv1 = nn.Sequential(
                self.conv(3, 64, 3, stride=2, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False))
        self.bn1 = self.norm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # build the residual stages
        self.layer1 = self._make_layer(block, 64,  block_num[0])
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
        # layer4 uses dilated (atrous) convolution at stride 1 to keep 1/16 resolution
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, dilation=2)
        # the CCNet (RCCA) head
        self.ccnet = RCCAModule(512 * block.expansion, 512, num_classes)
        # deep-supervision branch on layer3's output; its scores are concatenated with the RCCA scores in forward()
        self.dsn = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            self.norm(512),nn.ReLU(inplace=False),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        )
        # ASPP head; 512 * block.expansion is the channel count coming out of layer4
        self.aspp = ASPP(512 * block.expansion, 256, num_classes, conv=self.conv, norm=self.norm)
        # after concatenating the ASPP, RCCA, and dsn score maps, fuse them with a 1x1 convolution
        self.conv2 = self.conv(num_classes * 3, num_classes, kernel_size=1, stride=1, bias=False)
        # initialize every module in the network
        for m in self.modules():
            if isinstance(m, self.conv):  # convolution layers: He initialization
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))  # normal(0, sqrt(2/n))
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):  # norm layers
                m.weight.data.fill_(1)   # scale = 1
                m.bias.data.zero_()      # shift = 0

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        # stride != 1 (or a channel mismatch) means this stage changes resolution/width,
        # so the first block needs a 1x1 projection shortcut
        if stride != 1 or dilation != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                self.conv(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, dilation=max(1, dilation // 2), bias=False),
                self.norm(planes * block.expansion),
            )
        # layers collects the blocks of this stage
        layers = []
        # only the first block changes the channel count, so only it carries the downsample;
        # it also uses half the dilation (integer division keeps Conv2d's dilation an int)
        layers.append(block(self.inplanes, planes, stride, downsample, dilation=max(1, dilation // 2), conv=self.conv, norm=self.norm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, conv=self.conv, norm=self.norm))

        return nn.Sequential(*layers)


    def forward(self, x):
        # x: (batch_size, channels, H, W)
        size = (x.shape[2], x.shape[3])
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x_dsn = self.dsn(x)                           # deep supervision on layer3, e.g. torch.Size([4, 4, 20, 20])
        x_res = self.layer4(x)                        # e.g. torch.Size([4, 2048, 20, 20])
        # ASPP head
        x_aspp = self.aspp(x_res)                     # e.g. torch.Size([4, 4, 20, 20])
        # CCNet head, recurrence R=2
        x_ccnet_1 = self.ccnet(x_res, 2)              # e.g. torch.Size([4, 4, 20, 20])
        x_ccnet_2 = torch.cat([x_ccnet_1, x_dsn], 1)  # e.g. torch.Size([4, 8, 20, 20])
        out = torch.cat((x_aspp, x_ccnet_2), 1)       # e.g. torch.Size([4, 12, 20, 20])
        out = self.conv2(out)
        out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)
        return out
    
    def _load_pretrained_model(self):
        # load ImageNet ResNet-50 weights, keeping only the keys this extended model shares
        pretrain_dict = model_zoo.load_url(model_urls['resnet50'])
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

# model factory
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    # [3, 4, 6, 3] is block_num: the number of residual blocks in each stage
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=4, **kwargs)
    if pretrained:
        # strict=False: the checkpoint has no weights for the CCNet/ASPP/dsn heads
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
    return model
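
# Illustrative: how _make_layer applies dilation in layer4 (dilation=2): the
# first block gets dilation // 2 = 1, the remaining blocks the full rate.
if __name__ == "__main__":
    net = resnet50()
    print(len(net.layer4))               # 3 bottleneck blocks
    print(net.layer4[0].conv2.dilation)  # (1, 1)
    print(net.layer4[1].conv2.dilation)  # (2, 2)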



if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = resnet50()
    model = model.to(device)
    x = torch.rand((4, 3, 320, 320))
    x = x.to(device)
    print(x.shape)
    print('====================')
    output = model(x)
    print('====================')
    print(output.shape)
