# 残差块加 SE 模块的 PyTorch 实现 (Residual block with a Squeeze-and-Excitation module — PyTorch implementation)
# 作者：互联网 (Author: collected from the Internet)
class Residual(nn.Module):
    """ResNet-style basic residual block with a Squeeze-and-Excitation (SE) gate.

    Main branch: ``conv3x3 -> BN -> ReLU -> conv3x3 -> BN``. Its output is
    rescaled channel-wise by an SE gate (global average pool -> Linear ->
    ReLU -> Linear -> Sigmoid). The result is then added to the shortcut
    and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        use_1x1conv: If True, project the shortcut with a 1x1 convolution
            using the same ``stride``. This is needed whenever
            ``in_channels != out_channels`` or ``stride != 1``. Otherwise
            the identity shortcut does not match the main branch's shape.
        stride: Stride of the first 3x3 convolution (and of the 1x1 shortcut).
        reduction: Channel reduction ratio of the SE bottleneck. The hidden
            width is ``max(1, out_channels // reduction)``. The default of 1
            keeps the hidden width equal to ``out_channels`` (no
            bottleneck), which matches the original behavior.

    Raises:
        ValueError: If ``reduction < 1``.
    """

    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1,
                 reduction=1):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1)
        # Optional projection shortcut. A 1x1 conv with the same stride gives
        # the same output spatial size as the padded 3x3 conv above.
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # SE "squeeze": reduce each channel map to a single averaged value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        hidden = max(1, out_channels // reduction)
        # SE "excitation": produces one gate in (0, 1) per channel.
        self.fc = nn.Sequential(
            nn.Linear(out_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, X):
        """Apply the block.

        Args:
            X: Batched 4-D input of shape (batch, in_channels, H, W). The
                4-way unpack of the main-branch shape below requires 4 dims.

        Returns:
            Tensor of shape (batch, out_channels, H', W'), with
            ``H' = (H - 1) // stride + 1`` (same for W).
        """
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            X = self.conv3(X)
        b, c, _, _ = Y.shape
        gate = self.fc(self.avg_pool(Y).view(b, c)).view(b, c, 1, 1)
        # Broadcasting over H and W makes an explicit expand_as unnecessary.
        return F.relu(Y * gate + X)
# 标签 (tags): conv3, nn, self, 残差, channels, stride, pytorch, se, out
# 来源 (source): https://www.cnblogs.com/hahaah/p/15813322.html