ann2snn的代码分析
作者:互联网
首先,主函数是if_cnn_mnist_work.py
1.输出snn测试结果的就是这么一些代码:
utils.pytorch_ann2snn(model_name=model_name,
norm_tensor=norm_tensor,
test_data_loader=test_data_loader,
device=device,
T=T,
log_dir=log_dir,
config=config
)
2.ctrl+鼠标左键点击pytorch_ann2snn
可以看到里面的代码
def pytorch_ann2snn(model_name, norm_tensor, test_data_loader, device, T, log_dir, config,
load_state_dict=False, ann=None):
print("我在pytorch_ann2snn子函数中,其中设备使用的是"+device)
'''
* :ref:`API in English <pytorch_conversion-en>`
.. _standard_conversion-cn:
:param model_name: 模型名字,用于文件夹中寻找保存的模型
:param norm_tensor: 用于模型归一化的数据,其格式以能够作为网络输入为准。这部分数据应当从训练集抽取
:param test_data_loader: 测试数据加载器,用于仿真
:param device: 运行的设备
:param T: 仿真时长
:param log_dir: 用于保存临时文件的日志文件夹
:param config: 用于转换的配置
:param load_state_dict: 如果希望使用state dict加载的模型,将此参数设置为 ``True``
:param ann: 用于加载state dict的模型,使用的模块均为Pytorch内置模块
:return: ``None``
其中snn准确率是在这里计算得到的:
snn_acc = sim.simulate_snn(snn=snn,
device=device,
data_loader=test_data_loader,
T=T,
poisson=config['simulation']['encoder']['possion'],
fig_name=model_name,
ann_baseline=ann_acc*100,
log_dir=log_dir)
3.ctrl+鼠标左键点击simulata_snn
可以看到里面的代码,第64行
correct += (out_spikes_counter.max(1)[1] == label.to(device)).float().sum().item()
代表就是正确的数量
其中out_spikes_counter.max(1)[1]
就是目标输出的,label.to(device)).float().sum().item()
就是标签。如果他们两者相等,那么就把correct
加上一个1
4.接下来的工作就是改一下上述位置的代码,可以输出混淆矩阵就行了
标签:分析,name,snn,代码,ann2snn,param,loader,device,dir 来源: https://blog.csdn.net/huatianxue/article/details/117436411