torch_09_DCGAN_注意的细节
作者:互联网
DCGAN github链接:https://github.com/darr/DCGAN
DCGAN:
1.在一次epoch中,如果第i批的i能够整除every_print,则打印到output文件中(打印出来)
2.训练过程:计算损失,梯度下降
3.iters:所有epochs累加到一块,迭代总次数
如果iters能够整除save_print,或者进行到了最后一个epoch and i是最后一个epoch的最后一组,则保存该生成模型
具体操作:
1.将生成的64张图片放入G网络中,生成一组假图片测试G网络的效果,这些假图片放入img_lst中
2.将train_model的字典的state_dict(),使用torch.save(model_dict, .tar)保存到tar文件中
4.在show image中:
1.显示loss值的图片:生成G_D_losses.jpg
2.将生成的假图片保存成动画
1 def _save_img_list(img_list, save_path, config): 2 #_show_img_list(img_list) 3 metadata = dict(title='generator images', artist='Matplotlib', comment='Movie support!') 4 writer = ImageMagickWriter(fps=1,metadata=metadata) 5 ims = [np.transpose(i, (1, 2, 0)) for i in img_list] 6 fig, ax = plt.subplots() 7 with writer.saving(fig, "%s/img_list.gif" % save_path,500): 8 for i in range(len(ims)): 9 ax.imshow(ims[i]) 10 ax.set_title("step {}".format(i * config["save_every"])) 11 writer.grab_frame()
3.将生成的假图片保存成图片
1 def _save_img_list(img_list, save_path): # 假图片的列表,保存路径, 2 3 ims = [np.transpose(i, (1, 2, 0)) for i in img_list] 4 name_img = 0 5 for i in range(len(ims)): 6 plt.figure(figsize=(8, 8)) 7 plt.subplot(1, 2, 1) 8 plt.axis("off") 9 str_name = "fake Images"+str(name_img) 10 plt.title("fake Images"+str(name_img)) 11 name_img += 500 12 plt.imshow(ims[i]) 13 name = str_name 14 full_path_name = "%s/%s" % (save_path, name) 15 plt.savefig(full_path_name)
标签:09,name,img,torch,list,ims,plt,DCGAN,save 来源: https://www.cnblogs.com/shuangcao/p/11796642.html