nn.BatchNorm2d 的具体实现：训练模式下 running_mean / running_var 的更新方式与输出的归一化计算
作者:互联网
参考:https://blog.csdn.net/qq_38253797/article/details/116847588
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def _bn(batch=None, momentum=0.1, eps=1e-5):
    """Reproduce one training-mode ``nn.BatchNorm2d`` step by hand and compare.

    Per-channel statistics are computed over the batch and spatial axes
    (N, H, W) and checked against a freshly constructed
    ``nn.BatchNorm2d(affine=False)`` module fed the same batch:

    * running_mean = (1 - momentum) * 0 + momentum * mean
    * running_var  = (1 - momentum) * 1 + momentum * unbiased_var
    * output       = (x - mean) / sqrt(biased_var + eps)

    Args:
        batch: 4-D tensor laid out as (N, C, H, W). Defaults to
            ``torch.randn(3, 4, 5, 5)``.
        momentum: Momentum forwarded to ``nn.BatchNorm2d``. ``None`` selects
            PyTorch's cumulative moving average, whose factor on the very
            first batch is 1.
        eps: Stability term added to the variance, forwarded to
            ``nn.BatchNorm2d``.

    Returns:
        dict mapping "running_mean", "running_var" and "output" to ``True``
        when the hand-computed value matches the module's (within a small
        float32 tolerance), ``False`` otherwise.

    Raises:
        ValueError: if ``batch`` is not 4-D.
    """
    if batch is None:
        batch = torch.randn(3, 4, 5, 5)
    if batch.dim() != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got {batch.dim()}-D")

    num_channels = batch.shape[1]
    reduce_dims = (0, 2, 3)  # every axis except the channel axis
    # With momentum=None PyTorch uses 1 / num_batches_tracked, which is 1.0
    # for the first batch seen by a fresh module.
    factor = 1.0 if momentum is None else momentum

    batch_mean = batch.mean(dim=reduce_dims)
    # The running variance is updated with the unbiased (n - 1) estimate...
    batch_var_unbiased = batch.var(dim=reduce_dims, unbiased=True)
    # ...whereas the forward pass normalizes with the biased (n) estimate.
    batch_var_biased = batch.var(dim=reduce_dims, unbiased=False)

    # A freshly constructed module starts from running_mean = 0, running_var = 1.
    init_mean, init_var = 0.0, 1.0
    expected_mean = (1 - factor) * init_mean + factor * batch_mean
    expected_var = (1 - factor) * init_var + factor * batch_var_unbiased
    print(expected_mean)
    print(expected_var)

    m = nn.BatchNorm2d(num_channels, eps=eps, momentum=momentum, affine=False)
    out = m(batch)  # modules start in training mode, so batch statistics are used
    print(out.shape)
    print(m.running_mean)
    print(m.running_var)

    def _per_channel(t):
        # Reshape a (C,) vector so it broadcasts against (N, C, H, W).
        return t.reshape(1, num_channels, 1, 1)

    expected_out = (batch - _per_channel(batch_mean)) / torch.sqrt(
        _per_channel(batch_var_biased) + eps
    )

    def _close(a, b):
        # Loose tolerances: both sides are float32 but reduce in different orders.
        return torch.allclose(a, b, rtol=1e-4, atol=1e-5)

    return {
        "running_mean": _close(m.running_mean, expected_mean),
        "running_var": _close(m.running_var, expected_var),
        "output": _close(out, expected_out),
    }


if __name__ == '__main__':
    _bn()
标签: BatchNorm2d, nn, torch, batch, 具体, print, import  来源: https://www.cnblogs.com/dxscode/p/15902639.html