Pytorch笔记:诡异的索引操作 + too many indices for tensor of dimension 1的一种解决方法

发布时间:2023-11-27 08:30

问题

too many indices for tensor of dimension 1

具体的错误信息忘记截图,大概就是上面的意思,对应的错误代码如下

错误代码

# 下面代码是博主自己随意码的,具体要说一下怎么解决这种问题
import torch
index = torch.tensor([20,10,25,39,5,12])  # 这些index的值,对应着point中的索引
point = torch.ones(2,40)
# 然后用index中的数对point进行各点取值
result_x = point[0,index]
result_y = point[1,index] 

分析

这段代码的作用:pytorch的切片操作大家估计很熟悉,一般都是连续的操作比如:

a[4:,]  # 从行的索引4开始向后取,列全取
a[:4,]  # 从行的第一行开始取,取到索引4,但不包括索引4,列全取

但是,如果你想取的值不是连续的呢?你取的数,都是分开的,那要怎么切片操作???

博主之前在训练一个模型的时候,因为不知道怎么用类似于切片的操作,对tensor进行并行取值,用了双层循环,一个索引一个索引的取,一个模型大概跑了三天多,然后昨天改了一天,训练时间变成了1天多,简直飞起。

就上面代码的执行,会报以下的错误:

too many indices for tensor of dimension 1

也就是说,index的tensor里的indices太多了。然后网上找了一大堆的解决方法,都没有用,最后因为我之前的一部分代码中,用的索引取值(二维的,但都是单个数值)时,报了错误,说是indices必须时long、byte、bool类型(好像是这样说,忘了忘了),然后我就把size为1*1的tensor,强制转换为long类型,然后随手把下面报错的代码(也就是文章开头的代码)中的indices都转换为long类型了,然后就不报错。

正确代码

# 下面代码是博主自己随意码的,具体要说一下怎么解决这种问题
import torch
index = torch.tensor([20,10,25,39,5,12])  # 这些index的值,对应着point中的索引
point = torch.ones(2,40)
# 然后用index中的数对point进行按点取值
result_x = point[0,index.long()]
result_y = point[1,index.long()]   #只要加上这个转换,

只要加上long转换,解决上面报的错误too many indices for tensor of dimension 1,然后就可以跟连续切片操作一样,进行并行的,按点取值,大大节省了训练时间。

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号