其他分享
首页 > 其他分享> > 样式迁移---在网络搭建好之后,训练的是图片

样式迁移---在网络搭建好之后,训练的是图片

作者:互联网

基于CNN的样式迁移

在这里插入图片描述
$$$$

pretrained_net = torchvision.models.vgg19(pretrained=True)
style_layers, content_layers = [0, 5, 10, 19, 28], [25]# 样式(局部全局都有) 内容(取上层)
net = nn.Sequential(*[pretrained_net.features[i] for i in
                      range(max(content_layers + style_layers) + 1)])# 只取28层,如果后边还有,直接丢掉
         
# 按层数抽取特征
def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

# 定义LOSS
# 内容 Y_hat生成图片 Y 真实图片
def content_loss(Y_hat, Y):
    # 我们从动态计算梯度的树中分离目标:
    # 这是一个规定的值,而不是一个变量。
    return torch.square(Y_hat - Y.detach()).mean()


# 样式  如何匹配样式?两张图片的样式一样?n统计信息一样,一阶(就是均值)默认都为零。下面计算方差
def gram(X):
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)

def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()


# total varence  tv降噪
def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())



content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分别计算内容损失、样式损失和总变差损失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 对所有损失求和
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

为了抽取图像的内容特征和样式特征,我们可以选择VGG网络中某些层的输出。 一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。 为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即内容层,来输出图像的内容特征。 我们还从VGG中选择不同层的输出来匹配局部和全局的样式,这些图层也称为样式层。 正如 7.2节 中所介绍的,VGG网络使用了5个卷积块。 实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为样式层。 这些层的索引可以通过打印pretrained_net实例获取。……

# 在nn中定义一个矩阵,当然可以用某个图片替换。而不用随机初始化,主要是如此定义以后就能算梯度了
class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
   
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)                            # 每次抽取X的特征
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) # 计算LOSS
        
        l.backward()                                                    # 反向传播求梯度
        trainer.step()                                                  # 下降梯度数更新
        scheduler.step()
        
    return X

标签:styles,layers,style,样式,---,content,hat,contents,搭建
来源: https://blog.csdn.net/ResumeProject/article/details/120608135