[PyTorch] Tensor Broadcasting
- PyTorch Tensor Broadcasting
Broadcasting lets PyTorch perform element-wise arithmetic on two tensors whose shapes differ.
import torch
- Tensor + scalar
\[\begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} \oplus 1 = \begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 + 1 & 2 + 1 \\ 3 + 1 & 4 + 1 \end{bmatrix}\]
x = torch.Tensor([[1,2],
                  [3,4]])
x + 1
tensor([[2., 3.],
        [4., 5.]])
- Tensor + vector
\[\begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} \oplus \begin{bmatrix} 1 & 2 \end{bmatrix} = \begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 1 + 1 & 2 + 2 \\ 3 + 1 & 4 + 2 \end{bmatrix}\]
x = torch.Tensor([[1,2],
                  [3,4]])
y = torch.Tensor([1,2])
print(x.shape)
print(y.shape)
print(x + y, (x+y).shape)
torch.Size([2, 2])
torch.Size([2])
tensor([[2., 4.],
        [4., 6.]]) torch.Size([2, 2])
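Shapes are aligned from the trailing dimension, and any missing leading dimensions are treated as size 1:
[1, 1, 2]     [1, 1, 2]
[      2] --> [1, 1, 2]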
x = torch.Tensor([[[1,2]]])
y = torch.Tensor([1,2])
print(x.shape)
print(y.shape)
print(x + y, (x+y).shape)
torch.Size([1, 1, 2])
torch.Size([2])
tensor([[[2., 4.]]]) torch.Size([1, 1, 2])
- Tensor + tensor
\[\begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} \oplus \begin{bmatrix} 1\\ 2 \end{bmatrix} = \begin{bmatrix} 1 & 2\\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix} = \begin{bmatrix} 1 + 1 & 2 + 1 \\ 3 + 2 & 4 + 2 \end{bmatrix}\]
The size-1 dimension of the second tensor is stretched to match the first:
[2, 2]     [2, 2]
[2, 1] --> [2, 2]
x = torch.Tensor([[1,2],
                  [3,4]])
y = torch.Tensor([[1],
                  [2]])
print(x.shape)
print(y.shape)
print(x + y, (x+y).shape)
torch.Size([2, 2])
torch.Size([2, 1])
tensor([[2., 3.],
        [5., 6.]]) torch.Size([2, 2])
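The result above can be checked against the explicitly expanded matrix from the equation. The following is a minimal sketch, assuming `Tensor.expand` is an acceptable way to make the implicit step visible; the names `y_expanded`, `explicit`, and `implicit` are illustrative only and do not come from the original post.

```python
import torch

x = torch.Tensor([[1, 2],
                  [3, 4]])
y = torch.Tensor([[1],
                  [2]])

# Make the implicit broadcast explicit: stretch the size-1 column
# dimension of y to match x. expand() returns a view, not a copy.
y_expanded = y.expand(2, 2)   # rows become [1, 1] and [2, 2]

explicit = x + y_expanded     # addition of two same-shape tensors
implicit = x + y              # broadcasting performs the expansion itself

print(y_expanded)
print(torch.equal(explicit, implicit))
```

Both additions should give the same tensor, which suggests that broadcasting behaves as if the smaller tensor had been repeated along its size-1 dimension.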
- General Broadcasting Rules
Reference - NumPy.org
Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5
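The rule behind these examples, as described in the NumPy documentation, is that shapes are compared from the trailing dimension backwards, and two dimensions are compatible when they are equal or when one of them is 1. The sketch below is a plain-Python illustration of that rule, not PyTorch's or NumPy's actual implementation; `broadcast_shape` is a hypothetical helper name introduced here.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    result = []
    # Walk both shapes from the trailing dimension; a missing
    # leading dimension is treated as size 1.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            raise ValueError(f"dimensions {a} and {b} are not compatible")
    return tuple(reversed(result))


print(broadcast_shape((256, 256, 3), (3,)))      # (256, 256, 3)
print(broadcast_shape((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)
```

The two calls reproduce the Image/Scale and A/B examples listed above.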
- Broadcastable arrays
- Cases where broadcasting works
A      (2d array):  5 x 4
B      (1d array):      1
Result (2d array):  5 x 4

A      (2d array):  5 x 4
B      (1d array):      4
Result (2d array):  5 x 4

A      (3d array):  15 x 3 x 5
B      (3d array):  15 x 1 x 5
Result (3d array):  15 x 3 x 5

A      (3d array):  15 x 3 x 5
B      (2d array):       3 x 5
Result (3d array):  15 x 3 x 5

A      (3d array):  15 x 3 x 5
B      (2d array):       3 x 1
Result (3d array):  15 x 3 x 5
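The five shape pairs listed above can be tried directly in PyTorch. This is a small sketch that assumes adding zero-filled tensors is a fair stand-in for any element-wise operation; the `cases` and `results` names are illustrative.

```python
import torch

# (shape of A, shape of B, expected result shape), taken from the list above
cases = [
    ((5, 4), (1,), (5, 4)),
    ((5, 4), (4,), (5, 4)),
    ((15, 3, 5), (15, 1, 5), (15, 3, 5)),
    ((15, 3, 5), (3, 5), (15, 3, 5)),
    ((15, 3, 5), (3, 1), (15, 3, 5)),
]

results = []
for shape_a, shape_b, expected in cases:
    # The values do not matter here; only the resulting shape is inspected.
    out = torch.zeros(shape_a) + torch.zeros(shape_b)
    results.append(tuple(out.shape))
    print(shape_a, "+", shape_b, "->", tuple(out.shape))
```

Each printed result shape should match the corresponding `Result` line.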
- Cases where broadcasting fails
A      (1d array):  3
B      (1d array):  4 # trailing dimensions do not match

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3 # second from last dimensions mismatched
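For the incompatible pairs above, PyTorch raises a `RuntimeError` when the operation is attempted. The sketch below catches that error for both pairs; the `failing_cases` and `errors` names are illustrative, and the exact error message may differ between PyTorch versions, so only the exception type is printed.

```python
import torch

failing_cases = [
    ((3,), (4,)),         # trailing dimensions do not match
    ((2, 1), (8, 4, 3)),  # second-from-last dimensions mismatched (2 vs 4)
]

errors = []
for shape_a, shape_b in failing_cases:
    try:
        torch.zeros(shape_a) + torch.zeros(shape_b)
    except RuntimeError as exc:
        # Record which shape pair could not be broadcast.
        errors.append((shape_a, shape_b))
        print(shape_a, "+", shape_b, "failed:", type(exc).__name__)
```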