其他分享
首页 > 其他分享 > Resizer Model

Resizer Model

作者:互联网

Resizer Model

class Resizer(nn.Module):
    """Learnable resizer: a CNN correction added to a plainly interpolated input.

    Pipeline implemented by ``forward``:
      1. ``residual`` = interpolate(x)
      2. ``module1`` (7x7 conv -> LeakyReLU -> 1x1 conv -> LeakyReLU ->
         BatchNorm) applied to x, then interpolated -> ``out_residual``
      3. ``resblocks`` -> ``module3`` (3x3 conv + BatchNorm), plus
         ``out_residual``
      4. ``module4`` (7x7 conv), plus ``residual``

    Args:
        cfg: configuration read via attribute access. Fields used:
            ``cfg.resizer.interpolate_mode`` (an ``F.interpolate`` mode),
            ``cfg.resizer.num_kernels``, ``cfg.resizer.num_resblocks``,
            ``cfg.resizer.negative_slope``, ``cfg.resizer.in_channels``,
            ``cfg.resizer.out_channels``, ``cfg.data.image_size`` and
            ``cfg.data.resizer_image_size``. The scale factor is
            ``image_size / resizer_image_size``, so an input whose side is
            ``resizer_image_size`` comes out with side ``image_size``.

    Note:
        The final ``out + residual`` adds a tensor with ``out_channels``
        channels to one with ``in_channels`` channels. The two must be
        equal (or broadcast-compatible).
    """

    # F.interpolate accepts ``align_corners`` only for these modes. For
    # "nearest", "area" and "nearest-exact" it must be None, otherwise
    # F.interpolate raises ValueError.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, cfg: "DictConfig"):
        super().__init__()
        self.interpolate_mode = cfg.resizer.interpolate_mode
        # Output/input size ratio, e.g. resizer_image_size=448 and
        # image_size=224 give 0.5.
        self.scale_factor = cfg.data.image_size / cfg.data.resizer_image_size

        n = cfg.resizer.num_kernels
        r = cfg.resizer.num_resblocks
        slope = cfg.resizer.negative_slope

        self.module1 = nn.Sequential(
            nn.Conv2d(cfg.resizer.in_channels, n, kernel_size=7, padding=3),
            nn.LeakyReLU(slope, inplace=True),
            nn.Conv2d(n, n, kernel_size=1),
            nn.LeakyReLU(slope, inplace=True),
            nn.BatchNorm2d(n),
        )

        # Sub-module names ("0", "1", ...) match the previous loop-built
        # list, so existing state_dict keys still load.
        self.resblocks = nn.Sequential(*[ResBlock(n, slope) for _ in range(r)])

        self.module3 = nn.Sequential(
            nn.Conv2d(n, n, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(n),
        )

        self.module4 = nn.Conv2d(n, cfg.resizer.out_channels, kernel_size=7,
                                 padding=3)

        # Pass align_corners only where PyTorch allows it; keep False
        # (the previous value) for the linear-family modes.
        if self.interpolate_mode in self._ALIGN_CORNERS_MODES:
            align_corners = False
        else:
            align_corners = None

        self.interpolate = partial(F.interpolate,
                                   scale_factor=self.scale_factor,
                                   mode=self.interpolate_mode,
                                   align_corners=align_corners,
                                   recompute_scale_factor=False)

    def forward(self, x):
        """Resize ``x`` and add the learned correction.

        Args:
            x: 4-D tensor (N, in_channels, H, W). The BatchNorm2d in
                ``module1`` requires 4-D input.

        Returns:
            Tensor with spatial size (floor(H * scale_factor),
            floor(W * scale_factor)). Because ``recompute_scale_factor``
            is False, F.interpolate floors ``size * scale_factor``.
        """
        # Plain resize of the raw input; the CNN branch is added to it last.
        residual = self.interpolate(x)

        # Extract features at input resolution, then resize them.
        out = self.module1(x)
        out_residual = self.interpolate(out)

        # Residual stack at the target resolution, with a skip connection.
        out = self.resblocks(out_residual)
        out = self.module3(out)
        out = out + out_residual

        out = self.module4(out)

        out = out + residual

        return out

标签:nn,cfg,self,interpolate,resizer,Model,Resizer,out
来源: https://www.cnblogs.com/lwp-nicol/p/15517659.html