发布时间:2023-09-22 08:00
我在计算loss的过程中,遇到了以上错误
出错代码1:
temp = target
正确代码1:
temp = target.clone()
出错代码2:
output1[:, 0, :, :, :] = 1 - output1[:, 0, :, :, :]
temp = target
正确代码2:
output2 = output1.clone()
output2[:, 0, :, :, :] = 1 - output1[:, 0, :, :, :]
出错代码:
intersection1 = 2. * (output1 * target1).sum()#output1 和 target1大小不一,在计算过程中,target1会经过广播机制扩成和output1同样大小。
正确代码:
target1 = torch.cat([temp,temp],1)
intersection1 = 2. * (output1 * target1).sum()
当然,此错误还有其他原因