MoCo论文中的Algorithm 1伪代码解读

发布时间:2024-10-18 11:01

具体解读了什么东西

论文中提供的伪代码大约如下:
MoCo论文中的Algorithm 1伪代码解读_第1张图片

下面我将分步骤介绍这个代码干什么

1.query encoder和key encoder的参数初始化

其实也没表达什么就是一开始大家的参数是一样的:

f_k.params = f_q.params

2.之后就是loader当中取数据

这个也没啥的就是取出来数据的问题:

for x in loader: # load a minibatch x with N samples

3.数据增强

就是代码不是直接将内容输入其中,也会通过数据增强取出内容

x_q = aug(x) # a randomly augmented version
x_k = aug(x) # another randomly augmented version

4.核心操作

首先我们先理解一下这个N和C是什么?

q = f_q.forward(x_q) # queries: NxC
k = f_k.forward(x_k) # keys: NxC

N其实是一个batch_size
C是一个输入数据的特征数,每个输入数据是一个1×C的张量

k = k.detach() # no gradient to keys

这个其实就是文章的主要创新点了,因为优化key_encoder是来自于query_encoder的优化。所以自然就不需要前传梯度,也能剩下个内存。

这里是矩阵乘法,理解一下这里的矩阵乘法:

# positive logits: Nx1
l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
# negative logits: NxK
l_neg = mm(q.view(N,C), queue.view(C,K))
# logits: Nx(1+K)
logits = cat([l_pos, l_neg], dim=1)
  • 1.首先我们应当理解一下这个q和k到底是什么东西,可以看到q和k分别来自于x_q和x_k,我们注意这两个东西其实都来自于x只是作了不同的数据增强罢了。
    好了,现在我们应该能判断出来,这里的x和k我们应该认为同一个类别。
  • 2.l_pos 现在我们就知道这个东西应该是一个N*1的一组接近1的数值
  • 3.我们注意queue是我们存储的之前的batch的内容,所以这个东西和我们当前这个batch的内容应该是没有任何交集的,也就是他们来自于不同的内容,按照对比学习的思想,来自不同事物的内容应该完全不相交。所以他们的相似度应该尽量的低。
  • 4.l_neg应当得到一个N*K的一组接近0的数值。
  • 5.logits的内容就自然而然出现了,应该为一个N*(K+1)的内容,这些内容应该具有下面的特点:K+1的向量除了第一位接近1之外其他都应该接近0。
  • 6.在现在的情况下我们自然而然可以得出一个内容就是,每个(K+1)的张量经过softmax之后,模型都应该判别其为正确。也就是所有的N个张量都是0号分类。

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号