Deep Coral loss

发布时间:2024-03-27 17:01

import torch


def CORAL(source, target):
    d = source.data.shape[1] #coral公式中的分母部分
    ns, nt = source.data.shape[0], target.data.shape[0]
    # source covariance
    xm = torch.mean(source, 0, keepdim=True) - source #对应着Cs的分子部分
    xc = xm.t() @ xm/(ns-1)  #对应着Cs的分子部分

    # target covariance
    xmt = torch.mean(target, 0, keepdim=True) - target#对应着Ct的分子部分
    xct = xmt.t() @ xmt/(nt-1)#对应着Ct的分子部分

    # frobenius norm between source and target
    loss = torch.mean(torch.mul((xc - xct), (xc - xct))) #Cs-Ct的点乘
    loss = loss/(4*d*d)

    return loss

Coral公式:

\"\"

\"Deep

只做学习使用,作者也是看了别人进行了学习总结,希望能对你有所帮助。

 故障诊断与python学习

Deep CORAL: Correlation Alignment for Deep Domain Adaptation_gdtop818的博客-CSDN博客

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号