发布时间:2024-10-18 11:01
其实也没表达什么就是一开始大家的参数是一样的:
f_k.params = f_q.params
这个也没啥的就是取出来数据的问题:
for x in loader: # load a minibatch x with N samples
就是代码不是直接将内容输入其中,也会通过数据增强取出内容
x_q = aug(x) # a randomly augmented version
x_k = aug(x) # another randomly augmented version
首先我们先理解一下这个N和C是什么?
q = f_q.forward(x_q) # queries: NxC
k = f_k.forward(x_k) # keys: NxC
N其实是一个batch_size
C是一个输入数据的特征数,每个输入数据是一个1×C的张量
k = k.detach() # no gradient to keys
这个其实就是文章的主要创新点了,因为优化key_encoder是来自于query_encoder的优化。所以自然就不需要前传梯度,也能剩下个内存。
这里是矩阵乘法,理解一下这里的矩阵乘法:
# positive logits: Nx1
l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
# negative logits: NxK
l_neg = mm(q.view(N,C), queue.view(C,K))
# logits: Nx(1+K)
logits = cat([l_pos, l_neg], dim=1)