FSL-GNN代码解读
作者:互联网
FSL-GNN代码解读
main.py(主函数)
1、加载数据集:
train_loader = generator.Generator(args.dataset_root, args, partition='train', dataset=args.dataset)
2、初始化或加载模型:
enc_nn = models.load_model('enc_nn', args, io)
metric_nn = models.load_model('metric_nn', args, io)
if enc_nn is None or metric_nn is None:
enc_nn, metric_nn = models.create_models(args=args)
softmax_module = models.SoftmaxModule()
models.create_models(args=args)
: in models.py
def create_models(args):
print (args.dataset)
if 'omniglot' == args.dataset:
enc_nn = EmbeddingOmniglot(args, 64)
elif 'mini_imagenet' == args.dataset:
enc_nn = EmbeddingImagenet(args, 128)
else:
raise NameError('Dataset ' + args.dataset + ' not knows')
return enc_nn, MetricNN(args, emb_size=enc_nn.emb_size)
class EmbeddingOmniglot(): # 特征提取
class EmbeddingImagenet(): # 略
class MetricNN(nn.Module):
if self.metric_network == 'gnn_iclr_nl':…… # 正常的网络
self.gnn_obj = gnn_iclr.GNN_nl() # in gnn_iclr.py
elif self.metric_network == 'gnn_iclr_active':…… # 主动学习
self.gnn_obj = gnn_iclr.GNN_active()# in gnn_iclr.py
class SoftmaxModule(): # 线性分类
class GNN_nl(nn.Module) & class GNN_active(nn.Module)
: in gnn_iclr.py
class GNN_nl(nn.Module): # 图网络主要部分
class Wcompute(nn.Module) # W邻接矩阵计算
class Gconv(nn.Module) # 组图
def gmul(input) # 更新图节点特征,W直接返回
3、训练
# 权重衰减
weight_decay = 1e-6
# 优化器
opt_enc_nn = optim.Adam(enc_nn.parameters(), lr=args.lr, weight_decay=weight_decay)
opt_metric_nn = optim.Adam(metric_nn.parameters(), lr=args.lr, weight_decay=weight_decay)
# 梯度置零,也就是把loss关于weight的导数变成0
opt_enc_nn.zero_grad()
opt_metric_nn.zero_grad()
# 训练
loss_d_metric = train_batch(
model=[enc_nn, metric_nn,
softmax_module],
data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels])
# 更新参数
opt_enc_nn.step()
opt_metric_nn.step()
# 自适应参数
adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn], lr=args.lr, iter=batch_idx)
# 显示训练中loss的更新
if batch_idx % args.log_interval == 0:
display_str = 'Train Iter: {}'.format(batch_idx)
display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss/counter)
io.cprint(display_str)
# 测试
def test_one_shot(args, model, test_samples=5000, partition='test') 定义于 test.py 中
val_acc_aux = test.test_one_shot # 验证集上测试
test_acc_aux = test.test_one_shot # 测试集上测试
test.test_one_shot( # 训练集上测试
args,
model=[enc_nn, metric_nn, softmax_module],
test_samples=test_samples,
partition='train')
# 测试完毕,将模型设置回训练状态
enc_nn.train()
metric_nn.train()
# 若在验证集上的效果继续变好,则更新
if val_acc_aux is not None and val_acc_aux >= val_acc:
# 保存模型
torch.save(enc_nn, 'checkpoints/%s/models/enc_nn.t7' % args.exp_name)
torch.save(metric_nn, 'checkpoints/%s/models/metric_nn.t7' % args.exp_name)
# 全部训练完毕后进行测试
test.test_one_shot
标签:enc,nn,models,metric,args,FSL,解读,test,GNN 来源: https://www.cnblogs.com/SethDeng/p/15371286.html