其他分享
首页 > 其他分享> > 对比学习量化评价

对比学习量化评价

作者:互联网

在超球面上通过对齐和一致实现理解对比表示学习 —— 论文阅读笔记

两个对比损失最关键的要素:

image

torch 版本代码:

# bsz : batch size (number of positive pairs)
# d : latent dim
# x : Tensor, shape=[bsz, d]
# latents for one side of positive pairs
# y : Tensor, shape=[bsz, d]
# latents for the other side of positive pairs
# lam : hyperparameter balancing the two losses
def lalign(x, y, alpha=2):
    return (x - y).norm(dim=1).pow(alpha).mean()
def lunif(x, t=2):
    sq_pdist = torch.pdist(x, p=2).pow(2)
    return sq_pdist.mul(-t).exp().mean().log()
loss = lalign(x, y) + lam * (lunif(x) + lunif(y)) / 2

tensorflow 版本:

def lalign(x, y, alpha=2):
    """
    x: [bs, d] latents for one side of positive pairs
    y: [bs,d] latents for the other side of positive pairs
    """
    # 第二范数
    return tf.reduce_mean(tf.pow(tf.norm(x - y, axis=1), alpha))

def lunif(x, t=2):
    """
    x: [bs, d]
    """
    batch_size = tf.shape(x)[0]
    # 实现torch.pdist
    x=tf.cast(x, tf.float32)
    pdist_matrix = tf.norm(x[:, None]-x, axis=2)
    bool_mask = tf.cast(1-tf.linalg.band_part(tf.ones((batch_size,batch_size)),-1,0), bool) # 右上对角线
    pdist = pdist_matrix[bool_mask]
    sq_pdist = tf.pow(pdist, 2)
    return tf.math.log(tf.reduce_mean(tf.exp(-t*sq_pdist)))

标签:pairs,positive,sq,pdist,评价,tf,量化,pow,对比
来源: https://www.cnblogs.com/carolsun/p/16419791.html