pytorch小记(二):pytorch矩阵乘法:torch.cat(tensors, dim=0)
- 语法
- 使用规则
- 示例 1:在第 0 维(行)拼接
- 示例 2:在第 1 维(列)拼接
- 示例 3:在高维张量上拼接
- 初始张量
- 1. 在 `dim=0` 拼接
- 2. 在 `dim=1` 拼接
- 3. 在 `dim=2` 拼接
- 总结
- 示例 4:拼接不同形状的张量(错误示范)
- 总结
在 PyTorch 中,torch.cat()
是一种用于在指定维度上连接张量的操作。它能够将多个张量沿某个轴拼接成一个新的张量。
语法
torch.cat(tensors, dim=0)
tensors
:一个包含多个待拼接张量的列表或元组。这些张量在指定的dim
维度以外的所有维度上必须具有相同的形状。dim
:指定在哪个维度上进行拼接操作。
使用规则
- 在指定维度上,张量的形状可以不同(因为会拼接)。
- 在其他维度上,张量的形状必须相同。
示例 1:在第 0 维(行)拼接
x = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
y = torch.tensor([[5, 6], [7, 8]]) # 形状 (2, 2)result = torch.cat((x, y), dim=0) # 在第 0 维拼接
print(result)
输出:
tensor([[1, 2],[3, 4],[5, 6],[7, 8]])
- 原始张量
x
和y
在第 0 维上(行方向)拼接,因此新张量的形状为(4, 2)
。
示例 2:在第 1 维(列)拼接
x = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
y = torch.tensor([[5, 6], [7, 8]]) # 形状 (2, 2)result = torch.cat((x, y), dim=1) # 在第 1 维拼接
print(result)
输出:
tensor([[1, 2, 5, 6],[3, 4, 7, 8]])
- 原始张量
x
和y
在第 1 维上(列方向)拼接,因此新张量的形状为(2, 4)
。
示例 3:在高维张量上拼接
我们来创建两个高维张量 x
和 y
,并分别在不同维度(dim=0
, dim=1
, dim=2
)上使用 torch.cat
进行拼接,展示具体计算结果。
初始张量
x = torch.tensor([[[1, 2, 3], [4, 5, 6]],[[7, 8, 9], [10, 11, 12]]
]) # 形状 (2, 2, 3)y = torch.tensor([[[13, 14, 15], [16, 17, 18]],[[19, 20, 21], [22, 23, 24]]
]) # 形状 (2, 2, 3)
x
和y
是形状为(2, 2, 3)
的 3D 张量:x: [[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]] y: [[[13, 14, 15],[16, 17, 18]],[[19, 20, 21],[22, 23, 24]]]
1. 在 dim=0
拼接
result_dim0 = torch.cat((x, y), dim=0)
print(result_dim0.shape) # torch.Size([4, 2, 3])
print(result_dim0)
拼接逻辑:
- 在第 0 维度(最外层)拼接,结果张量包含 4 个“块”,每个“块”的形状仍然是
(2, 3)
。
结果:
result_dim0:
[[[ 1, 2, 3],[ 4, 5, 6]],[[ 7, 8, 9],[ 10, 11, 12]],[[ 13, 14, 15],[ 16, 17, 18]],[[ 19, 20, 21],[ 22, 23, 24]]]
2. 在 dim=1
拼接
result_dim1 = torch.cat((x, y), dim=1)
print(result_dim1.shape) # torch.Size([2, 4, 3])
print(result_dim1)
拼接逻辑:
- 在第 1 维度(每个“块”中的行)拼接,结果张量包含 2 个“块”,每个“块”增加了 2 行,形状从
(2, 3)
变为(4, 3)
。
结果:
result_dim1:
[[[ 1, 2, 3],[ 4, 5, 6],[ 13, 14, 15],[ 16, 17, 18]],[[ 7, 8, 9],[ 10, 11, 12],[ 19, 20, 21],[ 22, 23, 24]]]
3. 在 dim=2
拼接
result_dim2 = torch.cat((x, y), dim=2)
print(result_dim2.shape) # torch.Size([2, 2, 6])
print(result_dim2)
拼接逻辑:
- 在第 2 维度(每行中的列)拼接,结果张量包含 2 个“块”,每个“块”有 2 行,但每行的列数增加了一倍,从 3 列变为 6 列。
结果:
result_dim2:
[[[ 1, 2, 3, 13, 14, 15],[ 4, 5, 6, 16, 17, 18]],[[ 7, 8, 9, 19, 20, 21],[ 10, 11, 12, 22, 23, 24]]]
总结
dim 值 | 拼接维度 | 结果形状 | 拼接效果 |
---|---|---|---|
dim=0 | 最外层 | (4, 2, 3) | 增加块的数量(纵向堆叠) |
dim=1 | 每块的行数 | (2, 4, 3) | 增加每块的行数(横向堆叠行) |
dim=2 | 每行的列数 | (2, 2, 6) | 增加每行的列数(横向堆叠列) |
通过改变 dim
,torch.cat
可以在不同维度上灵活地拼接张量。
示例 4:拼接不同形状的张量(错误示范)
如果张量在非拼接维度上的形状不同,会抛出错误:
x = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
y = torch.tensor([[5, 6, 7]]) # 形状 (1, 3)result = torch.cat((x, y), dim=0) # 抛出错误
错误信息:
RuntimeError: Sizes of tensors must match except in dimension 0. Got 2 and 3 in dimension 1
如果希望在行方向 dim=0
拼接,可以通过 补零
或 裁剪
等方式使列数一致。
补零torch.nn.functional.pad
:
import torch
import torch.nn.functional as Fx = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
y = torch.tensor([[5, 6, 7]]) # 形状 (1, 3)# 对 x 补零到列数 3
x_padded = F.pad(x, (0, 1)) # 在列方向右侧补 1 列零
# x_padded 形状: (2, 3)# 在 dim=0 拼接
result = torch.cat((x_padded, y), dim=0)
print(result)
结果:
tensor([[1, 2, 0],[3, 4, 0],[5, 6, 7]])
但是
result = torch.cat((x_padded, y), dim=1)
则还是错误的!!!
总结
torch.cat()
用于连接张量,指定的dim
决定了在哪个维度上进行拼接。- 拼接维度的大小是累加的,其他维度的大小必须一致。
- 如果不满足上述规则,会抛出错误。
通过这种操作,你可以灵活地调整和组织张量的数据结构。