发布时间:2024-03-27 17:01
import torch
def CORAL(source, target):
d = source.data.shape[1] #coral公式中的分母部分
ns, nt = source.data.shape[0], target.data.shape[0]
# source covariance
xm = torch.mean(source, 0, keepdim=True) - source #对应着Cs的分子部分
xc = xm.t() @ xm/(ns-1) #对应着Cs的分子部分
# target covariance
xmt = torch.mean(target, 0, keepdim=True) - target#对应着Ct的分子部分
xct = xmt.t() @ xmt/(nt-1)#对应着Ct的分子部分
# frobenius norm between source and target
loss = torch.mean(torch.mul((xc - xct), (xc - xct))) #Cs-Ct的点乘
loss = loss/(4*d*d)
return loss
Coral公式:
只做学习使用,作者也是看了别人进行了学习总结,希望能对你有所帮助。
故障诊断与python学习
Deep CORAL: Correlation Alignment for Deep Domain Adaptation_gdtop818的博客-CSDN博客