GFPGAN Source Code Analysis (IV)
Author: reposted from the internet
2021SC@SDUSC
I. Code snippet under analysis
1. Code listing
import math

import torch.nn as nn


class GFPGANv1Clean(nn.Module):
    """GFPGANv1 Clean version."""

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):
        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        # Channel-width factor for the U-Net part: half of `narrow`.
        unet_narrow = narrow * 0.5
        # Number of feature channels at each spatial resolution (string keys).
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        # log2 of the output resolution, e.g. 9 for out_size=512.
        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))
        # 1x1 convolution: 3-channel RGB input -> channels[str(first_out_size)] feature maps.
        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
        # (the rest of __init__ is omitted in this excerpt)
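2. What the code does
GFPGANv1Clean inherits from nn.Module, the base class of all neural network modules in PyTorch. The excerpt covers only the beginning of its constructor: it stores configuration flags (input_is_latent, different_w, num_style_feat), builds the channels table, and creates the first convolution layer, conv_body_first.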
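II. Details
1. Setting up channels
In the actual call narrow=1, so unet_narrow = narrow * 0.5 = 0.5 and every entry of the table is half of its base value.
channels maps each feature-map resolution (4 up to 1024, stored as string keys) to the number of output channels of the convolution layers working at that resolution. Resolutions 4 to 32 all get int(512 * unet_narrow) channels; from 64 upward the base count halves each time the resolution doubles (256, 128, 64, 32, 16) and is additionally scaled by channel_multiplier.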
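To see what the table holds concretely, the sketch below re-evaluates the channels dictionary from the excerpt in plain Python. unet_channels is a hypothetical helper name, not part of GFPGAN; the channel_multiplier=2 case is included only for comparison, since the text above fixes just narrow=1 and the signature defaults channel_multiplier to 1.

```python
def unet_channels(narrow=1, channel_multiplier=1):
    """Re-evaluate the `channels` dict from GFPGANv1Clean.__init__ (hypothetical helper)."""
    unet_narrow = narrow * 0.5  # narrow=1 gives 0.5, halving every entry
    table = {}
    # Resolutions 4..32 share one width that ignores channel_multiplier.
    for res in (4, 8, 16, 32):
        table[str(res)] = int(512 * unet_narrow)
    # From 64 upward the base width halves at every doubling of resolution.
    for res, base in ((64, 256), (128, 128), (256, 64), (512, 32), (1024, 16)):
        table[str(res)] = int(base * channel_multiplier * unet_narrow)
    return table


if __name__ == "__main__":
    for cm in (1, 2):
        print(f"channel_multiplier={cm}:", unet_channels(narrow=1, channel_multiplier=cm))
```

With narrow=1 and the default channel_multiplier=1 this gives 256 channels for resolutions 4 to 32, then 128, 64, 32, 16 and 8 for resolutions 64 to 1024; channel_multiplier=2 doubles only the entries from 64 upward.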
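2. Building the convolutional network with torch.nn.Conv2d()
self.log_size = int(math.log(out_size, 2))
first_out_size = 2**(int(math.log(out_size, 2)))
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
self.log_size is the base-2 logarithm of the output resolution, e.g. 9 for out_size=512. first_out_size raises 2 back to that power, so for a power-of-two out_size it equals out_size itself and selects the matching entry of channels. conv_body_first is the first layer of the network: nn.Conv2d(in_channels, out_channels, kernel_size) with kernel_size=1, i.e. a 1×1 convolution that maps the 3-channel RGB input image to channels[f'{first_out_size}'] feature maps without changing its height or width.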
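A 1×1 convolution is easiest to see without the framework. The sketch below is a plain-Python stand-in for what nn.Conv2d(3, C, 1) computes, assuming PyTorch's defaults of stride 1, no padding and bias=True; conv1x1 and the demo weights are hypothetical, not GFPGAN code, and a 2×2 image replaces the real 512×512 input. It also re-evaluates the log_size and first_out_size expressions for out_size=512, with C = 16 taken from channels['512'] under narrow=1 and channel_multiplier=1.

```python
import math


def conv1x1(image, weight, bias):
    """Pure-Python 1x1 convolution (stride 1, no padding), hypothetical helper.

    image:  C_in x H x W nested lists
    weight: C_out x C_in (nn.Conv2d stores it as C_out x C_in x 1 x 1)
    bias:   length C_out
    Each output pixel is a weighted sum over the input channels at the same
    (y, x) position, so H and W are unchanged.
    """
    c_in = len(image)
    h, w = len(image[0]), len(image[0][0])
    return [
        [[b + sum(row[k] * image[k][y][x] for k in range(c_in)) for x in range(w)]
         for y in range(h)]
        for row, b in zip(weight, bias)
    ]


if __name__ == "__main__":
    out_size = 512
    log_size = int(math.log(out_size, 2))  # same expression as self.log_size
    first_out_size = 2 ** log_size         # equals out_size for powers of two
    c_out = 16  # channels['512'] when narrow=1 and channel_multiplier=1

    # Tiny 3-channel "RGB" image of size 2x2 instead of 512x512, to stay readable.
    image = [[[1, 2], [3, 4]], [[0, 0], [0, 0]], [[1, 1], [1, 1]]]
    weight = [[1.0, 0.0, float(o)] for o in range(c_out)]  # arbitrary demo weights
    bias = [0.5] * c_out
    out = conv1x1(image, weight, bias)

    print(log_size, first_out_size)
    print(len(out), len(out[0]), len(out[0][0]))  # 16 channels, still 2x2
    # Learnable parameters of nn.Conv2d(3, c_out, 1): c_out*3 weights plus c_out biases.
    print(c_out * 3 * 1 * 1 + c_out)
```

Because every output pixel depends only on the input pixel at the same position, conv_body_first changes the channel count (3 to 16 in this setting) but not the spatial size, and it has only C_out × 3 weights plus C_out biases.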
Source: https://blog.csdn.net/qq_45969525/article/details/122170451