发布时间:2023-11-27 08:30
too many indices for tensor of dimension 1
具体的错误信息忘记截图,大概就是上面的意思,对应的错误代码如下
# 下面代码是博主自己随意码的,具体要说一下怎么解决这种问题
import torch
index = torch.tensor([20,10,25,39,5,12]) # 这些index的值,对应着point中的索引
point = torch.ones(2,40)
# 然后用index中的数对point进行各点取值
result_x = point[0,index]
result_y = point[1,index]
这段代码的作用:pytorch的切片操作大家估计很熟悉,一般都是连续的操作比如:
a[4:,] # 从行的索引4开始向后取,列全取
a[:4,] # 从行的第一行开始取,取到索引4,但不包括索引4,列全取
但是,如果你想取的值不是连续的呢?你取的数,都是分开的,那要怎么切片操作???
博主之前在训练一个模型的时候,因为不知道怎么用类似于切片的操作,对tensor进行并行取值,用了双层循环,一个索引一个索引的取,一个模型大概跑了三天多,然后昨天改了一天,训练时间变成了1天多,简直飞起。
就上面代码的执行,会报以下的错误:
too many indices for tensor of dimension 1
也就是说,index的tensor里的indices太多了。然后网上找了一大堆的解决方法,都没有用,最后因为我之前的一部分代码中,用的索引取值(二维的,但都是单个数值)时,报了错误,说是indices必须时long、byte、bool类型(好像是这样说,忘了忘了),然后我就把size为1*1的tensor,强制转换为long类型,然后随手把下面报错的代码(也就是文章开头的代码)中的indices都转换为long类型了,然后就不报错。
# 下面代码是博主自己随意码的,具体要说一下怎么解决这种问题
import torch
index = torch.tensor([20,10,25,39,5,12]) # 这些index的值,对应着point中的索引
point = torch.ones(2,40)
# 然后用index中的数对point进行按点取值
result_x = point[0,index.long()]
result_y = point[1,index.long()] #只要加上这个转换,
只要加上long转换,解决上面报的错误too many indices for tensor of dimension 1
,然后就可以跟连续切片操作一样,进行并行的,按点取值,大大节省了训练时间。