其他分享
首页 > 其他分享 > matplotlib----attention可视化

matplotlib----attention可视化

作者:互联网

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import torch

# Build a small random "attention" matrix to visualise:
#   b     -> row-wise softmax of the 4x2 scores (every row of b sums to 1)
#   c     -> column-wise softmax, transposed to 2x4 (every row of c sums to 1)
#   d     -> b @ c, a 4x4 matrix whose rows therefore also sum to 1
a = torch.randn(4, 2)
b = torch.softmax(a, dim=1)
c = torch.softmax(a, dim=0).t()
print(a, '\n', b, '\n', c)
d = b @ c
print(d)

# Convert the torch attention matrix to NumPy and wrap it in a labelled
# DataFrame: the columns label the x axis, the index labels the y axis.
d = d.numpy()
variables = ['A','B','C','X']
labels = ['ID_0','ID_1','ID_2','ID_3']

df = pd.DataFrame(d, columns=variables, index=labels)

fig = plt.figure()

ax = fig.add_subplot(111)

# matshow draws one cell per matrix entry, centred on integer coordinates
# 0..n-1 along each axis.
cax = ax.matshow(df, interpolation='nearest', cmap='hot_r')
fig.colorbar(cax)

# Label every `tick_spacing`-th column/row. Tick positions are pinned
# explicitly with FixedLocator so label i lands on cell i. The previous
# MultipleLocator + ([''] + labels) trick only lined up because
# MultipleLocator happens to emit an extra tick one step below the first
# cell, which the padding '' absorbed. That is fragile, and on newer
# matplotlib it warns that FixedFormatter should be paired with
# FixedLocator.
tick_spacing = 1
x_positions = list(range(0, len(df.columns), tick_spacing))
y_positions = list(range(0, len(df.index), tick_spacing))
ax.xaxis.set_major_locator(ticker.FixedLocator(x_positions))
ax.yaxis.set_major_locator(ticker.FixedLocator(y_positions))

ax.set_xticklabels([str(df.columns[i]) for i in x_positions])
ax.set_yticklabels([str(df.index[i]) for i in y_positions])

plt.show()

参照博客: https://blog.csdn.net/m0_38133212/article/details/86664569

标签:set,df,attention,matplotlib,----,import,ax,ticker,ID
来源: https://blog.csdn.net/weiwei935707936/article/details/113795344