- PyTorch Tensor Broadcasting

Broadcasting lets PyTorch apply arithmetic operations to two tensors whose shapes differ.

import torch

- Tensor + scalar

\[\begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} \oplus 1 = \begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 + 1 & 2 + 1 \\ 3 + 1 & 4 + 1 \end{bmatrix}\]
x = torch.Tensor([[1,2],
                  [3,4]])
x + 1
tensor([[2., 3.],
        [4., 5.]])
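
As a quick check of the equation above, the following minimal sketch compares `x + 1` with adding an explicit matrix of ones. The names `ones`, `explicit`, and `broadcast` are introduced here only for illustration.

```python
import torch

x = torch.tensor([[1., 2.],
                  [3., 4.]])

# Build the all-ones matrix that the scalar is conceptually expanded into.
ones = torch.ones_like(x)

explicit = x + ones   # both operands are 2 x 2, so no broadcasting is needed
broadcast = x + 1     # the scalar is broadcast to every element of x

# Both routes should give the same tensor.
print(torch.equal(explicit, broadcast))
```

The broadcast form avoids building the intermediate matrix by hand, which is the main convenience of broadcasting.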

- Tensor + vector

\[\begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} \oplus \begin{bmatrix} 1 & 2 \end{bmatrix} = \begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 1 + 1 & 2 + 2 \\ 3 + 1 & 4 + 2 \end{bmatrix}\]
x = torch.Tensor([[1,2],
                  [3,4]])
y = torch.Tensor([1,2])
print(x.shape)
print(y.shape)
print(x + y, (x+y).shape)
torch.Size([2, 2])
torch.Size([2])
tensor([[2., 4.],
        [4., 6.]]) torch.Size([2, 2])


[1, 1, 2]     [1, 1, 2]
[      2] --> [1, 1, 2]
x = torch.Tensor([[[1,2]]])
y = torch.Tensor([1,2])
print(x.shape)
print(y.shape)
print(x + y, (x+y).shape)
torch.Size([1, 1, 2])
torch.Size([2])
tensor([[[2., 4.]]]) torch.Size([1, 1, 2])
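
To make the row-vector case more concrete, here is a small sketch that expands `y` by hand with `Tensor.expand` and confirms that the result matches what broadcasting produces. The variable name `y_expanded` is an assumption made for this example.

```python
import torch

x = torch.tensor([[1., 2.],
                  [3., 4.]])
y = torch.tensor([1., 2.])

# expand() returns a view with the requested shape; new leading dimensions
# are added at the front and the existing values are repeated along them.
y_expanded = y.expand(2, 2)
print(y_expanded)

# Adding the expanded view should match letting PyTorch broadcast y itself.
print(torch.equal(x + y, x + y_expanded))
```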

- Tensor + tensor

\[\begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} \oplus \begin{bmatrix} 1\\ 2 \end{bmatrix} = \begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix} = \begin{bmatrix} 1 + 1 & 2 + 1 \\ 3 + 2 & 4 + 2 \end{bmatrix}\]

[2,2]     [2,2]
[2,1] --> [2,2]
x = torch.Tensor([[1,2],
                  [3,4]])
y = torch.Tensor([[1],
                  [2]])
print(x.shape)
print(y.shape)
print(x + y, (x+y).shape)
torch.Size([2, 2])
torch.Size([2, 1])
tensor([[2., 3.],
        [5., 6.]]) torch.Size([2, 2])
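
The column-vector case can also be inspected with `torch.broadcast_tensors`, which returns both operands expanded to the common shape. This is only a sketch for illustration; `xb` and `yb` are names chosen here, and the stride printout is included to suggest that the expansion is a view rather than a copy.

```python
import torch

x = torch.tensor([[1., 2.],
                  [3., 4.]])
y = torch.tensor([[1.],
                  [2.]])

# Expand both operands to their common broadcast shape.
xb, yb = torch.broadcast_tensors(x, y)

print(yb)           # the column of y repeated along the second dimension
print(yb.shape)
print(yb.stride())  # the expanded dimension reuses the same memory

# The sum of the expanded operands should match the broadcast sum.
print(torch.equal(xb + yb, x + y))
```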

- General Broadcasting Rules

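Reference - NumPy.org

PyTorch follows NumPy's broadcasting semantics. Shapes are compared dimension by dimension, starting from the trailing (rightmost) dimension. Two dimensions are compatible when they are equal or when one of them is 1; missing leading dimensions are treated as 1.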

Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5
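
The two shape tables above can be checked without allocating large tensors by using `torch.broadcast_shapes`. The sketch below also applies the image example to real data; the per-channel values in `scale` are hypothetical and chosen only for illustration.

```python
import torch

# Image x Scale example: (256, 256, 3) with (3,)
print(torch.broadcast_shapes((256, 256, 3), (3,)))

# A x B example: (8, 1, 6, 1) with (7, 1, 5)
print(torch.broadcast_shapes((8, 1, 6, 1), (7, 1, 5)))

# Scaling each colour channel of a random image by its own factor.
image = torch.rand(256, 256, 3)
scale = torch.tensor([0.5, 1.0, 2.0])  # hypothetical per-channel factors
scaled = image * scale
print(scaled.shape)
```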

- Broadcastable arrays

- Cases where broadcasting works

A      (2d array):  5 x 4
B      (1d array):      1
Result (2d array):  5 x 4

A      (2d array):  5 x 4
B      (1d array):      4
Result (2d array):  5 x 4

A      (3d array):  15 x 3 x 5
B      (3d array):  15 x 1 x 5
Result (3d array):  15 x 3 x 5

A      (3d array):  15 x 3 x 5
B      (2d array):       3 x 5
Result (3d array):  15 x 3 x 5

A      (3d array):  15 x 3 x 5
B      (2d array):       3 x 1
Result (3d array):  15 x 3 x 5
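
The rule behind these tables can be written out in a few lines of plain Python. The function `broadcast_shape` below is a hypothetical helper written for this post, not part of PyTorch or NumPy; it is a minimal sketch that returns the result shape, or `None` when the shapes are incompatible.

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Return the broadcast shape of a and b, or None if they are incompatible."""
    result = []
    # Walk both shapes from the trailing dimension; a missing dimension counts as 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            return None  # sizes differ and neither is 1
    return tuple(reversed(result))

# The broadcastable cases listed above.
cases = [
    ((5, 4), (1,)),
    ((5, 4), (4,)),
    ((15, 3, 5), (15, 1, 5)),
    ((15, 3, 5), (3, 5)),
    ((15, 3, 5), (3, 1)),
]
for a, b in cases:
    print(a, b, "->", broadcast_shape(a, b))
```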

- Cases where broadcasting fails

A      (1d array):  3
B      (1d array):  4 # trailing dimensions do not match

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3 # second from last dimensions mismatched
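
In PyTorch, adding tensors with incompatible shapes raises a `RuntimeError`. The sketch below tries the two failing cases above; `can_add` is a hypothetical helper introduced for this example, and the last call is an added contrast case in which the second-to-last dimension is changed from 2 to 4 so that the shapes line up.

```python
import torch

def can_add(a, b):
    """Return True if a + b succeeds, False if PyTorch rejects the shapes."""
    try:
        a + b
        return True
    except RuntimeError as err:
        print("broadcast failed:", err)
        return False

# Trailing dimensions 3 and 4 do not match.
print(can_add(torch.ones(3), torch.ones(4)))

# Second-from-last dimensions 2 and 4 do not match.
print(can_add(torch.ones(2, 1), torch.ones(8, 4, 3)))

# Contrast case: (4, 1) against (8, 4, 3) is compatible.
print(can_add(torch.ones(4, 1), torch.ones(8, 4, 3)))
```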
