PyTorch:torch.Tensor.repeat()、expand()

发布时间:2024-12-02 14:01

目录

1、torch.Tensor.repeat()

2、torch.Tensor.expand()


1、torch.Tensor.repeat()

函数定义:

repeat(*sizes) → Tensor

作用:

在指定的维度上重复这个张量,即把这个维度的张量复制*sizes次。同时可以通过复制的形式扩展维度的数量。

注意:torch.Tensor.repeat方法与numpy.tile方法作用相似,而不是numpy.repeat!torch中与numpy.repeat类似的方法是torch.repeat_interleave!

区别:与expand的不同之处在于,repeat函数传入的参数直接就是对应维度要扩充的倍数,而不是最后的shape。

举例分析:

例1——对应(已存在的)维度的拓展。

import torch

a = torch.tensor([[1], [2], [3]])  # 3 * 1
b = a.repeat(3, 2)  # torch.linspace(0, 10, 5)
print('a:\n', a)
print('shape of a', a.size())  # 原始shape = (3,1)
print('b:\n', b)
print('shape of b', b.size())  # 新的shape = (3*3,1*2),新增加的数据通过复制得到

'''   运行结果   '''
a:
 tensor([[1],
        [2],
        [3]])
shape of a torch.Size([3, 1])  注: 原始shape = (3,1)
b:
 tensor([[1, 1],
        [2, 2],
        [3, 3],
        [1, 1],
        [2, 2],
        [3, 3],
        [1, 1],
        [2, 2],
        [3, 3]])
shape of b torch.Size([9, 2])  注: 新的shape = (3*3,1*2),新增加的数据通过复制得到

例2——带有(原始不存在的)维度数量拓展的用法:

import torch
a = torch.tensor([[1, 2], [3, 4], [5, 6]])  # 3 * 2
b = a.repeat(3, 2, 1)   # 在原始tensor的0维前拓展一个维度,并把原始tensor的第1维扩张2倍,都是通过复制来完成的
print('a:\n', a)
print('shape of a', a.size())  # 原始维度为 (3,2)
print('b:\n', b)
print('shape of b', b.size())  # 新的维度为 (3,2*2,2*1)=(3,4,2)

'''   运行结果   '''
a:
 tensor([[1, 2],
         [3, 4],
         [5, 6]])
shape of a torch.Size([3, 2])   注:原始维度为 (3,2)
b:
 tensor([[[1, 2],
          [3, 4],
          [5, 6],
          [1, 2],
          [3, 4],
          [5, 6]],
 
         [[1, 2],
          [3, 4],
          [5, 6],
          [1, 2],
          [3, 4],
          [5, 6]],

         [[1, 2],
          [3, 4],
          [5, 6],
          [1, 2],
          [3, 4],
          [5, 6]]])
shape of b torch.Size([3, 6, 2])   注:新的维度为 (3,2*2,2*1)=(3,4,2)

2、torch.Tensor.expand()

函数定义:

expand(*sizes) → Tensor

作用:不仅可以对tensor指定的(已存在的)维度进行扩大(复制型扩大),扩大后的shape为*(size)。而且还有类似于unsqueeze的维度扩充功能,新增加的维度将会加在前面。

区别:与repeat不同之处在于,expand传入的参数直接就是将tensor扩大后的shape。

举例说明

a = torch.ones(3, 1)   # 创建3*1的全为1的tensor
b = a.expand(3, 2)     # 对a的维度1进行扩充
c = a.expand(2, 3, 2)  # 对a的维度1进行扩充,并在维度0前加一个维度 
print('a:', a)
print('shape of a:', a.shape)
print('b:', b)
print('shape of b:', b.shape)
print('c:', c)
print('shape of c:', c.shape)

'''   运行结果   '''
a: tensor([[1.],
           [1.],
           [1.]])
shape of a: torch.Size([3, 1])

b: tensor([[1., 1.],
           [1., 1.],
           [1., 1.]])
shape of b: torch.Size([3, 2])

c: tensor([[[1., 1.],
            [1., 1.],
            [1., 1.]],
           [[1., 1.],
            [1., 1.],
            [1., 1.]]])
shape of c: torch.Size([2, 3, 2])

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号