其他分享
首页 > 其他分享> > 读《PyTorch + NumPy这么做会降低模型准确率,这是bug还是预期功能?》

读《PyTorch + NumPy这么做会降低模型准确率,这是bug还是预期功能?》

作者:互联网

看了文章: 

【转载】 浅谈PyTorch的可重复性问题(如何使实验结果可复现)

 

 

然后,转到:

PyTorch + NumPy这么做会降低模型准确率,这是bug还是预期功能?

 

 

发现了在pytorch中的一个容易被忽略的问题,那就是多进程操作时各个进程其实是和父进程有着相同的随机种子的,重点不在于各个子进程和父进程随机种子相同,重点的是这些子进程之间的随机种子也是相同的,那么在这些子进程中进行的任何相同顺序的随机数生成也会是相同的,这个现象有可能导致自己的代码运行获得不到自己计划得到的结果,因此该现象应该被注意。

 

其实该种现象还是很常见的,如果同时在一个linux系统中fork生成100个进程,每个进程都是以当前系统时间作为随机种子,那么这100个进程的随机种子也是完全相同的,这个问题是很容易被忽视的。

 

 

 

 

原文中的表述:

PyTorch 使用多进程并行加载数据,worker 进程是使用 fork start 方法创建的。这意味着每个工作进程继承父进程的所有资源,包括 NumPy 的随机数生成器的状态。

 

 

 

 

 

 

 

各个子进程设置不同的随机种子的方法:            (引自:https://www.cnblogs.com/devilmaycry812839668/p/14693658.html

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

 

 

 

 

 

 

 

 

 

重点在函数:   worker_init_fn

 

该函数在各个子进程初始的时候执行,我们可以在这个函数中进行设置以使各个子进程的随机种子不相同。

 

标签:GLOBAL,worker,PyTorch,seed,随机,做会,进程,NumPy,种子
来源: https://www.cnblogs.com/devilmaycry812839668/p/15840632.html