其他分享
首页 > 其他分享 > 【Pytorch】DCGAN实战(三):二次元动漫头像生成

【Pytorch】DCGAN实战(三):二次元动漫头像生成

作者:互联网

文章目录


1.实现效果

使用DCGAN训练faces数据集,最终实现生成二次元动漫头像。
最后虽然生成了动漫头像,但是一些细节还是和真实的图像差别较大,比如说眼睛大小,眼睛颜色等。
之后我会将MNIST数据集、Oxford17数据集以及faces数据集在训练过程中不同轮次的输出结果做一个总结。
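生成二次元动漫头像的程序依然沿用data.py、model.py、net.py、main.py四个文件,但具体的编程细节有所改变。注意:下文main.py中的data_path仍为"./oxford17_class",训练faces数据集时需改为faces数据集所在的目录(ImageFolder要求图片放在根目录下的子文件夹中)。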
之前基于MNIST以及Oxford17数据集的程序见这里:
【Pytorch】DCGAN实战(一):基于MNIST数据集的手写数字生成
【Pytorch】DCGAN实战(二):基于Oxford17的鲜花图像生成

2.环境配置

2.1Python

Python版本为3.7

2.2Pytorch、CUDA

在这里不详细介绍了,网上有很多的安装教程,小伙伴们自行查找吧!

2.3Python IDE

Pycharm

3.具体实现

整体分为4个文件:data.py、model.py、net.py、main.py

3.1数据预处理(data.py)

(1)导入包

from torch.utils.data import DataLoader
from torchvision import utils, datasets, transforms

(2)定义数据类

class ReadData():
    """Load an image folder as a dataset and serve it through a DataLoader.

    Images are resized and center-cropped to ``image_size``, converted to
    tensors and normalized per channel with mean 0.5 / std 0.5.
    """

    def __init__(self, data_path, image_size=64):
        self.root = data_path
        self.image_size = image_size
        # The dataset is built eagerly so its size is printed at construction.
        self.dataset = self.getdataset()

    def getdataset(self):
        """Build the ``ImageFolder`` dataset rooted at ``self.root``."""
        side = self.image_size
        pipeline = transforms.Compose([
            transforms.Resize(side),
            transforms.CenterCrop(side),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        images = datasets.ImageFolder(root=self.root, transform=pipeline)
        print(f'Total Size of Dataset: {len(images)}')
        return images

    def getdataloader(self, batch_size=128):
        """Return a shuffling, single-process DataLoader over ``self.dataset``."""
        return DataLoader(self.dataset, batch_size=batch_size,
                          shuffle=True, num_workers=0)

3.2 模型:Generator、Discriminator与权重初始化(model.py)

(1)导入包

import torch.nn as nn

(2)Generator

class Generator(nn.Module):
    """DCGAN generator: upsample a latent tensor into an ``nc``-channel image.

    Five kernel-4 transposed convolutions take an ``(N, nz, 1, 1)`` input to
    ``(N, nc, 64, 64)``. The final ``Tanh`` keeps outputs in [-1, 1], the
    same range that ``Normalize((0.5,)*3, (0.5,)*3)`` maps real images to.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()
        self.nz = nz
        self.ngf = ngf
        self.nc = nc

        def up(c_in, c_out, stride, padding):
            # Repeating unit: ConvTranspose -> BatchNorm -> ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]
        # nz x 1 x 1 -> (ngf*8) x 4 x 4
        layers = up(nz, widths[0], 1, 0)
        # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16 -> ngf x 32 x 32
        for c_in, c_out in zip(widths, widths[1:]):
            layers += up(c_in, c_out, 2, 1)
        # ngf x 32 x 32 -> nc x 64 x 64
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()]
        # Layer order is fixed so saved state_dict keys (main.<idx>.*) line up.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch through the upsampling stack."""
        return self.main(input)

(3)Discriminator

class Discriminator(nn.Module):
    """DCGAN discriminator: score an ``nc`` x 64 x 64 image with a probability.

    Four stride-2 convolutions downsample to ``(ndf*8) x 4 x 4``. A final
    4x4 convolution reduces that to ``1 x 1 x 1``, and ``Sigmoid`` squashes
    the score into [0, 1].
    """

    def __init__(self, ndf, nc):
        super().__init__()
        self.ndf = ndf
        self.nc = nc

        # nc x 64 x 64 -> ndf x 32 x 32 (no BatchNorm on the first layer)
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # ndf x 32 -> (ndf*2) x 16 -> (ndf*4) x 8 -> (ndf*8) x 4
        for mult in (1, 2, 4):
            c_in, c_out = ndf * mult, ndf * mult * 2
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # (ndf*8) x 4 x 4 -> 1 x 1 x 1 probability
        layers += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        # Layer order is fixed so saved state_dict keys (main.<idx>.*) line up.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return per-image real/fake probabilities of shape (N, 1, 1, 1)."""
        return self.main(input)

(4)权重初始化

def weights_init(m):
    """DCGAN-style init, meant to be used via ``model.apply(weights_init)``.

    Modules whose class name contains ``Conv`` get weights ~ N(0, 0.02).
    Modules whose name contains ``BatchNorm`` get weights ~ N(1, 0.02) and
    zero bias. All other modules are left untouched.
    """
    name = type(m).__name__
    if 'Conv' in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

3.3网络训练(net.py)

(1)导入包

import torch
import torch.nn as nn
from torchvision import utils, datasets, transforms
import time
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import os

(2)创建类

class DCGAN():
    """Adversarial trainer for a DCGAN generator/discriminator pair.

    Holds both networks (moved to ``device``), one Adam optimizer per
    network, a ``BCELoss`` criterion, the training ``DataLoader`` and a fixed
    batch of ``num_showimage`` latent vectors used to visualise progress.
    Checkpoints are written under ``model_save_path`` and figures under
    ``figure_save_path``. Both are combined with ``os.path.join``, so a
    trailing separator is optional.
    """

    def __init__(self, lr, beta1, nz, batch_size, num_showimage, device,
                 model_save_path, figure_save_path, generator, discriminator,
                 data_loader,):
        self.real_label = 1
        self.fake_label = 0
        self.nz = nz
        self.batch_size = batch_size
        self.num_showimage = num_showimage
        self.device = device
        self.model_save_path = model_save_path
        self.figure_save_path = figure_save_path

        self.G = generator.to(device)
        self.D = discriminator.to(device)
        self.opt_G = torch.optim.Adam(self.G.parameters(), lr=lr, betas=(beta1, 0.999))
        self.opt_D = torch.optim.Adam(self.D.parameters(), lr=lr, betas=(beta1, 0.999))
        self.criterion = nn.BCELoss().to(device)

        self.dataloader = data_loader
        # Same latent vectors every epoch, so per-epoch grids are comparable.
        self.fixed_noise = torch.randn(self.num_showimage, nz, 1, 1, device=device)

        self.img_list = []     # one image grid per epoch, generated from fixed_noise
        self.G_loss_list = []  # generator loss per iteration
        self.D_loss_list = []  # discriminator loss per iteration
        self.D_x_list = []     # mean D(x) on real images per iteration
        self.D_z_list = []     # mean D(G(z)) after the D update, per iteration

    def train(self, num_epochs):
        """Run the adversarial training loop for ``num_epochs`` epochs.

        Each iteration does one discriminator update and one generator update
        and prints the losses. After every epoch, both state dicts are saved
        as ``disc_<epoch>.pth`` / ``gen_<epoch>.pth``, and a grid generated
        from ``fixed_noise`` is appended to ``img_list``. At the end, the loss
        curves, the D(x)/D(G(z)) curves, a GIF of the grids and a real-vs-fake
        comparison are saved to ``figure_save_path``.
        """
        # The end of a previous train() call and test() both put G in eval
        # mode; BatchNorm must use batch statistics while training.
        self.G.train()
        self.D.train()

        num_batches = len(self.dataloader)
        epoch_width = len(str(num_epochs))
        step_width = len(str(num_batches))
        print("Starting Training Loop...")
        for epoch in range(num_epochs):
            beg_time = time.time()
            for i, data in enumerate(self.dataloader):
                loss_d, loss_g, d_x, d_gz1, d_gz2 = self._train_step(data[0].to(self.device))
                run_time = round(time.time() - beg_time)
                print(
                    f'Epoch: [{epoch + 1:0>{epoch_width}}/{num_epochs}]',
                    f'Step: [{i + 1:0>{step_width}}/{num_batches}]',
                    f'Loss-D: {loss_d:.4f}',
                    f'Loss-G: {loss_g:.4f}',
                    f'D(x): {d_x:.4f}',
                    f'D(G(z)): [{d_gz1:.4f}/{d_gz2:.4f}]',
                    f'Time: {run_time}s',
                    end='\r\n'
                )
                # Record plain floats (not tensors) so no autograd graph is kept alive.
                self.G_loss_list.append(loss_g)
                self.D_loss_list.append(loss_d)
                self.D_x_list.append(d_x)
                self.D_z_list.append(d_gz2)

            self._save_checkpoint(epoch)
            # Snapshot the generator's progress on the fixed noise batch.
            with torch.no_grad():
                fake = self.G(self.fixed_noise).detach().cpu()
            self.img_list.append(utils.make_grid(fake * 0.5 + 0.5, nrow=10))
            print()

        self._save_training_figures(num_epochs)

        # Final comparison using the generator as it is after training.
        self.G.eval()
        with torch.no_grad():
            fake = self.G(self.fixed_noise).cpu()
        self._plot_real_vs_fake(
            fake, self._figure_path('{}epochs_result.jpg'.format(num_epochs)), num=4)

    def test(self, epoch):
        """Plot real images next to samples from the generator saved at ``epoch``.

        Loads ``gen_<epoch>.pth`` from ``model_save_path`` into ``self.G``,
        switches it to eval mode, generates from ``fixed_noise``, and saves
        the comparison as ``result.jpg`` in ``figure_save_path``.
        """
        # Load the generator's own checkpoint; the discriminator's state dict
        # has different keys and cannot be loaded into G.
        state = torch.load(self._checkpoint_path('gen', epoch),
                           map_location=torch.device(self.device))
        self.G.load_state_dict(state)
        self.G.eval()
        with torch.no_grad():
            # .cpu(): matplotlib cannot draw tensors that live on a CUDA device.
            fake = self.G(self.fixed_noise.to(self.device)).cpu()
        self._plot_real_vs_fake(fake, self._figure_path('result.jpg'))

    def _train_step(self, x):
        """Do one D update then one G update on the real batch ``x``.

        Returns ``(loss_d, loss_g, d_x, d_gz1, d_gz2)`` as Python floats.
        ``d_gz1`` is the mean D(G(z)) before the D update and ``d_gz2`` the
        mean after it.
        """
        b_size = x.size(0)

        # (1) Update D: maximize log(D(x)) + log(1 - D(G(z)))
        lbx = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device)
        D_x = self.D(x).view(-1)
        LossD_x = self.criterion(D_x, lbx)

        z = torch.randn(b_size, self.nz, 1, 1, device=self.device)
        gz = self.G(z)
        lbz1 = torch.full((b_size,), self.fake_label, dtype=torch.float, device=self.device)
        # detach(): the D update must not push gradients into G.
        D_gz1 = self.D(gz.detach()).view(-1)
        LossD_gz1 = self.criterion(D_gz1, lbz1)

        LossD = LossD_x + LossD_gz1
        self.opt_D.zero_grad()
        LossD.backward()
        self.opt_D.step()

        # (2) Update G: maximize log(D(G(z))); fake labels are real for the G cost.
        lbz2 = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device)
        D_gz2 = self.D(gz).view(-1)
        LossG = self.criterion(D_gz2, lbz2)
        self.opt_G.zero_grad()
        LossG.backward()
        self.opt_G.step()

        return (LossD.item(), LossG.item(), D_x.mean().item(),
                D_gz1.mean().item(), D_gz2.mean().item())

    def _save_checkpoint(self, epoch):
        """Save D and G state dicts for ``epoch`` under ``model_save_path``."""
        self._ensure_dir(self.model_save_path)
        torch.save(self.D.state_dict(), self._checkpoint_path('disc', epoch))
        torch.save(self.G.state_dict(), self._checkpoint_path('gen', epoch))

    def _save_training_figures(self, num_epochs):
        """Save the loss curves, the D(x)/D(G(z)) curves and the progress GIF."""
        self._ensure_dir(self.figure_save_path)
        prefix = '{}epochs_'.format(num_epochs)

        plt.figure(1, figsize=(8, 4))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(self.G_loss_list[::10], label="G")
        plt.plot(self.D_loss_list[::10], label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.axhline(y=0, label="0", c="g")  # asymptote
        plt.legend()
        plt.savefig(self._figure_path(prefix + 'loss.jpg'), bbox_inches='tight')

        plt.figure(2, figsize=(8, 4))
        plt.title("D(x) and D(G(z)) During Training")
        plt.plot(self.D_x_list[::10], label="D(x)")
        plt.plot(self.D_z_list[::10], label="D(G(z))")
        plt.xlabel("iterations")
        plt.ylabel("Probability")
        plt.axhline(y=0.5, label="0.5", c="g")  # asymptote
        plt.legend()
        plt.savefig(self._figure_path(prefix + 'D(x)D(G(z)).jpg'), bbox_inches='tight')

        fig = plt.figure(3, figsize=(5, 5))
        plt.axis("off")
        ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in self.img_list]
        ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
        ani.save(self._figure_path(prefix + 'generation.gif'))

    def _plot_real_vs_fake(self, fake, save_file, num=None):
        """Show one real batch beside ``fake`` (CPU tensor in [-1, 1]) and save it.

        ``num`` is passed to ``plt.figure``; ``None`` creates a new figure.
        """
        plt.figure(num, figsize=(8, 4))

        plt.subplot(1, 2, 1)
        plt.axis("off")
        plt.title("Real Images")
        real = next(iter(self.dataloader))  # real[0]: images, real[1]: labels
        plt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))

        plt.subplot(1, 2, 2)
        plt.axis("off")
        plt.title("Fake Images")
        plt.imshow(utils.make_grid(fake[:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))

        self._ensure_dir(self.figure_save_path)
        plt.savefig(save_file, bbox_inches='tight')
        plt.show()

    def _checkpoint_path(self, kind, epoch):
        """Path of the ``kind`` ('disc' or 'gen') checkpoint for ``epoch``."""
        return os.path.join(self.model_save_path, '{}_{}.pth'.format(kind, epoch))

    def _figure_path(self, name):
        """Path of figure file ``name`` inside ``figure_save_path``."""
        return os.path.join(self.figure_save_path, name)

    @staticmethod
    def _ensure_dir(path):
        """Create directory ``path`` if needed ('' means the current directory)."""
        if path:
            os.makedirs(path, exist_ok=True)

3.4 主函数(main.py)

(1)导入文件

from data import ReadData
from model import Discriminator, Generator, weights_init
from net import DCGAN
import torch

(2)定义超参数

# ---- model / optimisation hyper-parameters ---------------------------------
ngpu = 1             # > 0 allows CUDA when available
ngf = 64             # base channel width of the Generator
ndf = 64             # base channel width of the Discriminator
nc = 3               # image channels
nz = 100             # length of the latent vector z
# NOTE(review): the DCGAN paper / PyTorch DCGAN tutorial use lr = 2e-4;
# 3e-3 is ~15x larger — confirm this is intentional.
lr = 0.003
beta1 = 0.5          # Adam beta1
batch_size = 100
num_showimage = 100  # number of images in each preview grid

# ---- paths ------------------------------------------------------------------
# NOTE(review): this article trains on the faces dataset, but the path still
# points at the Oxford17 folder from part (2) — confirm before running.
data_path = "./oxford17_class"
model_save_path = "./models/"
figure_save_path = "./figures/"

device = torch.device(
    'cuda:0' if torch.cuda.is_available() and ngpu > 0 else 'cpu'
)

(3)实例化

# Data pipeline.
dataset = ReadData(data_path)
dataloader = dataset.getdataloader(batch_size=batch_size)

# Networks with DCGAN weight initialisation (apply() returns the module).
G = Generator(nz, ngf, nc).apply(weights_init)
print(G)
D = Discriminator(ndf, nc).apply(weights_init)
print(D)

# Trainer wiring everything together.
dcgan = DCGAN(
    lr=lr,
    beta1=beta1,
    nz=nz,
    batch_size=batch_size,
    num_showimage=num_showimage,
    device=device,
    model_save_path=model_save_path,
    figure_save_path=figure_save_path,
    generator=G,
    discriminator=D,
    data_loader=dataloader,
)

(4)进行训练

# Train for 20 epochs (the curves in section 4 were produced with 200).
dcgan.train(20)

4.训练过程

4.1 Generator和Discriminator的Loss损失曲线图

训练过程中Generator和Discriminator的Loss曲线图(以200个epoch为例;上面main.py中默认num_epochs=20,复现时需相应修改):
Generator和Discriminator的Loss损失曲线图

4.2 D(x)和D(G(z))曲线图

训练过程中Discriminator输出(以200个epoch为例):
D(x)和D(G(z))曲线图

4.3最终生成结果图

训练结束后生成图片(以5个epoch为例):
最终生成结果图

5.完整代码

链接:https://pan.baidu.com/s/15J6sZL3rCPLm2jZFEuyzNw
提取码:DGAN

6.引用参考

https://blog.csdn.net/qq_42951560/article/details/112199229
https://blog.csdn.net/qq_42951560/article/details/110308336

7.问题反馈

如果运行有问题,欢迎给我私信留言!

标签:plt,nn,self,torch,二次元,Pytorch,DCGAN,path,size
来源: https://blog.csdn.net/qq_44031210/article/details/120111225