解决RuntimeError: one of the variables needed for gradient computation has been modified by an inplace

发布时间:2023-01-01 10:30

问题描述:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation 

原因分析:

跑过一个GAN的Super-Resolution的网络, 在训练的时候总是报上面的错误,显示loss_backward()的过程中权重等不对,在百度查了好久都是给相同的答案,估计都是一个复制一个,一个复制一个的,当时气死了, 源代码是这样的

            netD.zero_grad()
            #optimizerD.step()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)
            #d_loss.backward()
            optimizerD.step()
    

            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()

            fake_img = netG(z)
            fake_out = netD(fake_img).mean()
            #this code is applied to the pytorch1.4 version and before.
            #because the optimizer.step() actually modifies the weights of the model inplace while the original value of these weights is needed to compute the loss.backward() !!!!

里面有我的注释,网上给的大多数方案,应该是几乎所有,都是关于命名的,比如传递值的时候尽量不用相同的名字,这样会造成梯度反向传播的时候会造成网络判断不对, 真的吐了…
其实是因为pytorch版本到达1.5以后的autograd变了而已…

解决方案:

1.其实很简单,第一种简单的方案就是把pytorch恢复到1.4之前的环境,就可以了, 还是附上Ubuntu命令

conda create -n torch python=3.7
conda activate torch
pip install torch==1.2.0 torchvision==0.4.0 
#如果网速太慢可以用清华源  后面加 -i https://pypi.tuna.tsinghua.edu.cn/simple some-package
pip install torch==1.2.0 torchvision==0.4.0 -i https://pypi.tuna.tsinghua.edu.cn/simple some-package
pip install numpy matplotlib scipy==1.2.0 pandas scikit-learn scikit-image
#一些常用的

2.第二种方法就是改一下代码就可以了, 把更新梯度的步骤调后放在一起即可, 以我的为例:

            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            netD.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            netG.zero_grad()

            d_loss.backward(retain_graph=True)
            g_loss.backward()

            optimizerD.step()
            optimizerG.step()

            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

解决了, 希望复制党少一点, 太没有营养了,

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号