二维tensor:
对于两个张量a
和b
a
的形状是(m, n)
,
那么b
的形状必须是(n,p)
这样a
和b
才能相乘。结果张量的形状将是(m, p)
。
import torch
import numpy as npnp.random.seed(2022)
a = np.random.randint(low=0, high=2, size=(4, 8))
a = torch.tensor(a)
b = np.random.randint(low=0, high=2, size=(8, 9))
b = torch.tensor(b)
c = torch.matmul(a, b)
# or
# c = a @ b
print(c.size())
最后结果为:torch.Size([4, 9])
三维tensor:
a
的形状是(*, m, n)
,其中*
是任意数量的批次维度。b
的形状是(*, n, p)
,其中*
必须与a
中的批次维度相同,n
是收缩维度,它必须与a
中的n
相匹配。】- 最后的结果为(
*, m,p
)
import torch
import numpy as npnp.random.seed(2022)
a = np.random.randint(low=0, high=2, size=(5,4, 8))
a = torch.tensor(a)
b = np.random.randint(low=0, high=2, size=(5,8, 9))
b = torch.tensor(b)
c = torch.matmul(a, b)
# or
# c = a @ b
print(c.size())
最后结果为:torch.Size([5, 4, 9])
四维tensor:
a
的形状是(*,c, m, n)
,其中*
是任意数量的批次维度,c可以理解为通道维度。b
的形状是(*,c, n, p)
,其中*
必须与a
中的批次维度相同,c必须与a中的通道维度相同,n
必须与a
中的n
相匹配。- 最后的结果为(
*,c, m,p
)
import torch
import numpy as npnp.random.seed(2022)
a = np.random.randint(low=0, high=2, size=(7,5,4, 8))
a = torch.tensor(a)
b = np.random.randint(low=0, high=2, size=(7,5,8, 9))
b = torch.tensor(b)
c = torch.matmul(a, b)
# or
# c = a @ b
print(c.size())
最后结果为:torch.Size([7, 5, 4, 9])