其他分享
首页 > 其他分享 > few-shot-gnn代码阅读

few-shot-gnn代码阅读

作者:互联网

训练

模型分为两个网络：负责提取特征的 Embedding 层，以及基于图神经网络的度量层（MetricNN）。

EmbeddingOmniglot

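在 Omniglot 数据集上使用的 Embedding 网络（EmbeddingOmniglot）结构如下。它由 4 个卷积 + BatchNorm 模块和一个全连接层组成，最终输出 64 维的嵌入向量。其中 fc_last 的输入维度 576 = 64 × 3 × 3，推测是最后一层卷积输出的 64 个通道在 3×3 空间尺寸上展平后的结果。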
EmbeddingOmniglot(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_last): Linear(in_features=576, out_features=64, bias=False)
  (bn_last): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

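MetricNN（GNN 度量层）

MetricNN 的结构如下。从各层维度可以看出节点特征的变化：
- 输入节点特征为 84 维，推测是 64 维嵌入向量与 20 维标签 one-hot 向量的拼接。20 与最后一层 Gconv 的输出维度一致，即类别数。
- layer_l0 和 layer_l1 各输出 48 维新特征，并与原有特征拼接。因此节点特征维度依次为 84 → 132 → 180，这也正是各 Wcompute 的输入通道数。
- 每个 Gconv 中 fc 的输入维度都是当前节点特征维度的 2 倍（168 = 2 × 84，264 = 2 × 132，360 = 2 × 180）。推测这是把两种图算子（如单位矩阵与学习到的邻接矩阵 W）聚合后的结果拼接在一起。
- 最后的 layer_last 输出 20 维，用于对查询样本进行分类。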

MetricNN(
  (gnn_obj): GNN_nl_omniglot(
    (layer_w0): Wcompute(
      (conv2d_1): Conv2d(84, 168, kernel_size=(1, 1), stride=(1, 1))
      (bn_1): BatchNorm2d(168, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_2): Conv2d(168, 126, kernel_size=(1, 1), stride=(1, 1))
      (bn_2): BatchNorm2d(126, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_3): Conv2d(126, 84, kernel_size=(1, 1), stride=(1, 1))
      (bn_3): BatchNorm2d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_4): Conv2d(84, 84, kernel_size=(1, 1), stride=(1, 1))
      (bn_4): BatchNorm2d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_last): Conv2d(84, 1, kernel_size=(1, 1), stride=(1, 1))
    )
    (layer_l0): Gconv(
      (fc): Linear(in_features=168, out_features=48, bias=True)
      (bn): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (layer_w1): Wcompute(
      (conv2d_1): Conv2d(132, 264, kernel_size=(1, 1), stride=(1, 1))
      (bn_1): BatchNorm2d(264, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_2): Conv2d(264, 198, kernel_size=(1, 1), stride=(1, 1))
      (bn_2): BatchNorm2d(198, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_3): Conv2d(198, 132, kernel_size=(1, 1), stride=(1, 1))
      (bn_3): BatchNorm2d(132, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_4): Conv2d(132, 132, kernel_size=(1, 1), stride=(1, 1))
      (bn_4): BatchNorm2d(132, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_last): Conv2d(132, 1, kernel_size=(1, 1), stride=(1, 1))
    )
    (layer_l1): Gconv(
      (fc): Linear(in_features=264, out_features=48, bias=True)
      (bn): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (w_comp_last): Wcompute(
      (conv2d_1): Conv2d(180, 264, kernel_size=(1, 1), stride=(1, 1))
      (bn_1): BatchNorm2d(264, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.3, inplace=False)
      (conv2d_2): Conv2d(264, 198, kernel_size=(1, 1), stride=(1, 1))
      (bn_2): BatchNorm2d(198, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_3): Conv2d(198, 132, kernel_size=(1, 1), stride=(1, 1))
      (bn_3): BatchNorm2d(132, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_4): Conv2d(132, 132, kernel_size=(1, 1), stride=(1, 1))
      (bn_4): BatchNorm2d(132, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2d_last): Conv2d(132, 1, kernel_size=(1, 1), stride=(1, 1))
    )
    (layer_last): Gconv(
      (fc): Linear(in_features=360, out_features=20, bias=True)
      (bn): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)

标签:shot,05,0.1,gnn,affine,eps,few,stats,True
来源: https://blog.csdn.net/nini_k/article/details/118522237