其他分享
首页 > 其他分享> > ann2snn的代码分析

ann2snn的代码分析

作者:互联网

首先,主函数是if_cnn_mnist_work.py
1.输出snn测试结果的就是这么一些代码:

    utils.pytorch_ann2snn(model_name=model_name,
                              norm_tensor=norm_tensor,
                              test_data_loader=test_data_loader,
                              device=device,
                              T=T,
                              log_dir=log_dir,
                              config=config
                              )

2.ctrl+鼠标左键点击pytorch_ann2snn
可以看到里面的代码

def pytorch_ann2snn(model_name, norm_tensor, test_data_loader, device, T, log_dir, config,
                        load_state_dict=False, ann=None):
    print("我在pytorch_ann2snn子函数中,其中设备使用的是"+device)
    '''
    * :ref:`API in English <pytorch_conversion-en>`

    .. _standard_conversion-cn:

    :param model_name: 模型名字,用于文件夹中寻找保存的模型
    :param norm_tensor: 用于模型归一化的数据,其格式以能够作为网络输入为准。这部分数据应当从训练集抽取
    :param test_data_loader: 测试数据加载器,用于仿真
    :param device: 运行的设备
    :param T: 仿真时长
    :param log_dir: 用于保存临时文件的日志文件夹
    :param config: 用于转换的配置
    :param load_state_dict: 如果希望使用state dict加载的模型,将此参数设置为 ``True`` 
    :param ann: 用于加载state dict的模型,使用的模块均为Pytorch内置模块
    :return: ``None``

其中snn准确率是在这里计算得到的:

    snn_acc = sim.simulate_snn(snn=snn,
                               device=device,
                               data_loader=test_data_loader,
                               T=T,
                               poisson=config['simulation']['encoder']['possion'],
                               fig_name=model_name,
                               ann_baseline=ann_acc*100,
                               log_dir=log_dir)

3.ctrl+鼠标左键点击simulata_snn
可以看到里面的代码,第64行

correct += (out_spikes_counter.max(1)[1] == label.to(device)).float().sum().item()

代表就是正确的数量
其中out_spikes_counter.max(1)[1]就是目标输出的,label.to(device)).float().sum().item()就是标签。如果他们两者相等,那么就把correct加上一个1
4.接下来的工作就是改一下上述位置的代码,可以输出混淆矩阵就行了

标签:分析,name,snn,代码,ann2snn,param,loader,device,dir
来源: https://blog.csdn.net/huatianxue/article/details/117436411