神经网络高维互信息计算Python实现(MINE)
作者:互联网
论文
Belghazi, Mohamed Ishmael, et al. “ Mutual information neural estimation .” International Conference on Machine Learning . 2018.
利用神经网络的梯度下降法可以实现快速高维连续随机变量之间互信息的估计,上述论文提出了Mutual Information Neural Estimator (MINE)。NN在维度和样本量上都是线性可伸缩的,MI的计算可以通过反向传播进行训练。
核心
Python实现
现有github上的代码无法计算和估计高维随机变量,只能计算一维随机变量,下面的代码给出的修改方案能够计算真实和估计高维随机变量的真实互信息。
其中,为了计算理论的真实互信息,我们不直接暴力求解矩阵(耗时,这也是为什么要有MINE的原因),我们采用给定生成随机变量的参数计算理论互信息。
SIGNAL_NOISE = 0.2 SIGNAL_POWER = 3
完整代码基于pytorch
# Name: MINE_simple # Author: Reacubeth # Time: 2020/12/15 18:49 # Mail: noverfitting@gmail.com # Site: www.omegaxyz.com # *_*coding:utf-8 *_* import numpy as np import torch import torch.nn as nn from tqdm import tqdm import matplotlib.pyplot as plt SIGNAL_NOISE = 0.2 SIGNAL_POWER = 3 data_dim = 3 num_instances = 20000 def gen_x(num, dim): return np.random.normal(0., np.sqrt(SIGNAL_POWER), [num, dim]) def gen_y(x, num, dim): return x + np.random.normal(0., np.sqrt(SIGNAL_NOISE), [num, dim]) def true_mi(power, noise, dim): return dim * 0.5 * np.log2(1 + power/noise) mi = true_mi(SIGNAL_POWER, SIGNAL_NOISE, data_dim) print('True MI:', mi) hidden_size = 10 n_epoch = 500 class MINE(nn.Module): def __init__(self, hidden_size=10): super(MINE, self).__init__() self.layers = nn.Sequential(nn.Linear(2 * data_dim, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)) def forward(self, x, y): batch_size = x.size(0) tiled_x = torch.cat([x, x, ], dim=0) idx = torch.randperm(batch_size) shuffled_y = y[idx] concat_y = torch.cat([y, shuffled_y], dim=0) inputs = torch.cat([tiled_x, concat_y], dim=1) logits = self.layers(inputs) pred_xy = logits[:batch_size] pred_x_y = logits[batch_size:] loss = - np.log2(np.exp(1)) * (torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y)))) # compute loss, you'd better scale exp to bit return loss model = MINE(hidden_size) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) plot_loss = [] all_mi = [] for epoch in tqdm(range(n_epoch)): x_sample = gen_x(num_instances, data_dim) y_sample = gen_y(x_sample, num_instances, data_dim) x_sample = torch.from_numpy(x_sample).float() y_sample = torch.from_numpy(y_sample).float() loss = model(x_sample, y_sample) model.zero_grad() loss.backward() optimizer.step() all_mi.append(-loss.item()) fig, ax = plt.subplots() ax.plot(range(len(all_mi)), all_mi, label='MINE Estimate') ax.plot([0, len(all_mi)], [mi, mi], label='True Mutual Information') ax.set_xlabel('training steps') ax.legend(loc='best') plt.show()
结果
变量维度为1
变量维度为3
需要指出的是在计算最终的互信息时需要将基数e转为基数2。如果只是求得一个比较值,在真实使用的过程中可以省略。
本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理
想要获取更多Python学习资料可以加
QQ:2955637827私聊
或加Q群630390733
大家一起来学习讨论吧!
标签:dim,Python,互信息,torch,mi,MINE,sample,size 来源: https://www.cnblogs.com/putao11111/p/14156531.html