其他分享
首页 > 其他分享> > GFPGAN源代码分析(四)

GFPGAN源代码分析(四)

作者:互联网

2021SC@SDUSC

一、分析的代码片段

1.代码展示

class GFPGANv1Clean(nn.Module):
    """GFPGANv1 Clean version."""

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

2.代码作用分析

该类继承自nn.Module类

二、具体应用

1.channels的设置

实际调用的时候narrow=1,

channels保存了经过convolution层后的输出的通道数

2.调用torch.nn.Conv2d()搭建卷积神经网络

    self.log_size = int(math.log(out_size, 2))
    first_out_size = 2**(int(math.log(out_size, 2)))
    self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

标签:分析,narrow,int,self,unet,GFPGAN,源代码,out,size
来源: https://blog.csdn.net/qq_45969525/article/details/122170451