发布时间:2024-12-02 14:01
目录
1、torch.Tensor.repeat()
2、torch.Tensor.expand()
函数定义:
repeat(*sizes) → Tensor
作用:
在指定的维度上重复这个张量,即把这个维度的张量复制*sizes次。同时可以通过复制的形式扩展维度的数量。
注意:torch.Tensor.repeat方法与numpy.tile方法作用相似,而不是numpy.repeat!torch中与numpy.repeat类似的方法是torch.repeat_interleave!
区别:与expand的不同之处在于,repeat函数传入的参数直接就是对应维度要扩充的倍数,而不是最后的shape。
举例分析:
例1——对应(已存在的)维度的拓展。
import torch
a = torch.tensor([[1], [2], [3]]) # 3 * 1
b = a.repeat(3, 2) # torch.linspace(0, 10, 5)
print('a:\n', a)
print('shape of a', a.size()) # 原始shape = (3,1)
print('b:\n', b)
print('shape of b', b.size()) # 新的shape = (3*3,1*2),新增加的数据通过复制得到
''' 运行结果 '''
a:
tensor([[1],
[2],
[3]])
shape of a torch.Size([3, 1]) 注: 原始shape = (3,1)
b:
tensor([[1, 1],
[2, 2],
[3, 3],
[1, 1],
[2, 2],
[3, 3],
[1, 1],
[2, 2],
[3, 3]])
shape of b torch.Size([9, 2]) 注: 新的shape = (3*3,1*2),新增加的数据通过复制得到
例2——带有(原始不存在的)维度数量拓展的用法:
import torch
a = torch.tensor([[1, 2], [3, 4], [5, 6]]) # 3 * 2
b = a.repeat(3, 2, 1) # 在原始tensor的0维前拓展一个维度,并把原始tensor的第1维扩张2倍,都是通过复制来完成的
print('a:\n', a)
print('shape of a', a.size()) # 原始维度为 (3,2)
print('b:\n', b)
print('shape of b', b.size()) # 新的维度为 (3,2*2,2*1)=(3,4,2)
''' 运行结果 '''
a:
tensor([[1, 2],
[3, 4],
[5, 6]])
shape of a torch.Size([3, 2]) 注:原始维度为 (3,2)
b:
tensor([[[1, 2],
[3, 4],
[5, 6],
[1, 2],
[3, 4],
[5, 6]],
[[1, 2],
[3, 4],
[5, 6],
[1, 2],
[3, 4],
[5, 6]],
[[1, 2],
[3, 4],
[5, 6],
[1, 2],
[3, 4],
[5, 6]]])
shape of b torch.Size([3, 6, 2]) 注:新的维度为 (3,2*2,2*1)=(3,4,2)
函数定义:
expand(*sizes) → Tensor
作用:不仅可以对tensor指定的(已存在的)维度进行扩大(复制型扩大),扩大后的shape为*(size)。而且还有类似于unsqueeze的维度扩充功能,新增加的维度将会加在前面。
区别:与repeat不同之处在于,expand传入的参数直接就是将tensor扩大后的shape。
举例说明:
a = torch.ones(3, 1) # 创建3*1的全为1的tensor
b = a.expand(3, 2) # 对a的维度1进行扩充
c = a.expand(2, 3, 2) # 对a的维度1进行扩充,并在维度0前加一个维度
print('a:', a)
print('shape of a:', a.shape)
print('b:', b)
print('shape of b:', b.shape)
print('c:', c)
print('shape of c:', c.shape)
''' 运行结果 '''
a: tensor([[1.],
[1.],
[1.]])
shape of a: torch.Size([3, 1])
b: tensor([[1., 1.],
[1., 1.],
[1., 1.]])
shape of b: torch.Size([3, 2])
c: tensor([[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]]])
shape of c: torch.Size([2, 3, 2])