pytorch & numpy广播法则
阅读原文时间:2023年07月11日阅读:1

广播法则

  1. 所有数组向维度最高的数组看齐,若维度不足则在最前面的维度用1补齐

  2. 扩展维度后,所有数组在某一维度相同或者长度为1,否则不能计算

  3. 当可以计算时,将长度为1的维度扩展为另一数组相应维度的长度

    a = torch.ones(3, 2)
    b = torch.zeros(2,3,1)
    a + b

    a : (3, 2)-->(1, 3, 2)

    a : (1, 3, 2)-->(2, 3, 2)

    b : (2, 3, 1)-->(2, 3, 2)

    a + b : (2, 3, 2)

手工实现广播(建议,较为直观):

a.view(1, 3, 2).expand(2, 3, 2)
b.expand(2, 3, 2)
# repeat和expand功能类似,但是repeat会把数据复制多份,会占用额外空间

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器