其他分享
首页 > 其他分享> > torch_09_DCGAN_注意的细节

torch_09_DCGAN_注意的细节

作者:互联网

DCGAN github链接:https://github.com/darr/DCGAN

DCGAN:
1.在一次epoch中,如果第i批的i能够整除every_print,则打印到output文件中(打印出来)
2.训练过程:计算损失,梯度下降
3.iters:所有epochs累加到一块,迭代总次数
    如果iters能够整除save_print,或者进行到了最后一个epoch and i是最后一个epoch的最后一组,则保存该生成模型
具体操作:
     1.将生成的64张图片放入G网络中,生成一组假图片测试G网络的效果,这些假图片放入img_lst中
     2.将train_model的字典的state_dict(),使用torch.save(model_dict, .tar)保存到tar文件中

4.在show image中: 

       1.显示loss值的图片:生成G_D_losses.jpg
     2.将生成的假图片保存成动画      

 1 def _save_img_list(img_list, save_path, config):
 2     #_show_img_list(img_list)
 3     metadata = dict(title='generator images', artist='Matplotlib', comment='Movie support!')
 4     writer = ImageMagickWriter(fps=1,metadata=metadata)
 5     ims = [np.transpose(i, (1, 2, 0)) for i in img_list]
 6     fig, ax = plt.subplots()
 7     with writer.saving(fig, "%s/img_list.gif" % save_path,500):
 8         for i in range(len(ims)):
 9             ax.imshow(ims[i])
10             ax.set_title("step {}".format(i * config["save_every"]))
11             writer.grab_frame()

    3.将生成的假图片保存成图片    

 1 def _save_img_list(img_list, save_path): # 假图片的列表,保存路径,
 2     
 3     ims = [np.transpose(i, (1, 2, 0)) for i in img_list]
 4     name_img = 0
 5     for i in range(len(ims)):
 6         plt.figure(figsize=(8, 8))
 7         plt.subplot(1, 2, 1)
 8         plt.axis("off")
 9         str_name = "fake Images"+str(name_img)
10         plt.title("fake Images"+str(name_img))
11         name_img += 500
12         plt.imshow(ims[i])
13         name = str_name
14         full_path_name = "%s/%s" % (save_path, name)
15         plt.savefig(full_path_name) 

 

标签:09,name,img,torch,list,ims,plt,DCGAN,save
来源: https://www.cnblogs.com/shuangcao/p/11796642.html