PyTorch torch.cat
- 1. `torch.cat`
- 2. Example
- 3. Example
- References
torch
https://pytorch.org/docs/stable/torch.html
1. torch.cat
https://pytorch.org/docs/stable/generated/torch.cat.html
torch.cat(tensors, dim=0, *, out=None) -> Tensor
Concatenates the given sequence of `seq` tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size `(0,)`.
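The following is a minimal sketch of the shape rule above. It assumes a recent PyTorch release, and the tensor names are illustrative. Inputs that match everywhere except `dim` concatenate cleanly. A mismatch outside `dim` raises `RuntimeError`. A 1-D empty tensor of size `(0,)` is skipped instead of being shape-checked.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)  # differs from `a` only in dim 0

# Allowed: shapes agree everywhere except the concatenating dimension (dim=0).
c = torch.cat((a, b), dim=0)
print(c.shape)  # torch.Size([6, 3])

# Not allowed: when concatenating along dim=1, the sizes in dim 0 (2 vs 4) must match.
try:
    torch.cat((a, b), dim=1)
    mismatch_ok = True
except RuntimeError as err:
    mismatch_ok = False
    print("RuntimeError:", err)

# Special case: a 1-D empty tensor of size (0,) is accepted next to any shape and ignored.
empty = torch.empty(0)
d = torch.cat((a, empty), dim=0)
print(d.shape)  # torch.Size([2, 3])
```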
`torch.cat()` can be seen as an inverse operation for `torch.split()` and `torch.chunk()`: splitting a tensor and concatenating the pieces along the same dimension gives back the original tensor.
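A small round-trip check of the inverse relationship follows. It is a sketch with illustrative sizes. `torch.split` takes a chunk size and `torch.chunk` takes a number of chunks. In both cases, concatenating along the same `dim` restores the input.

```python
import torch

x = torch.arange(10).reshape(5, 2)

# torch.split with an int gives pieces of that size along dim (the last one may be smaller).
pieces = torch.split(x, 2, dim=0)       # sizes along dim 0: 2, 2, 1
restored = torch.cat(pieces, dim=0)

# torch.chunk asks for a number of chunks instead of a chunk size.
halves = torch.chunk(x, 2, dim=1)       # two tensors of shape (5, 1)
restored_cols = torch.cat(halves, dim=1)

print([p.shape[0] for p in pieces])     # [2, 2, 1]
print(torch.equal(restored, x), torch.equal(restored_cols, x))  # True True
```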
`torch.cat()` is best understood through examples.
`torch.stack()`, by contrast, concatenates the given sequence along a new dimension.
- Parameters
tensors (sequence of Tensors)
- any Python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
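Both a list and a tuple work as the `tensors` argument, as the sketch below shows with illustrative values. A common pattern is to collect pieces in a list and call `torch.cat` once at the end. Each `torch.cat` call allocates a new tensor and copies every input, so calling it repeatedly inside a loop copies the accumulated data again on every iteration.

```python
import torch

# Collect pieces first (here: three rows filled with 0, 1, 2) ...
parts = [torch.full((1, 4), float(i)) for i in range(3)]

# ... then concatenate once; a list and a tuple give the same result.
from_list = torch.cat(parts, dim=0)
from_tuple = torch.cat(tuple(parts), dim=0)

print(from_list.shape, torch.equal(from_list, from_tuple))  # torch.Size([3, 4]) True
```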
dim (int, optional)
- the dimension over which the tensors are concatenated (default: `0`)
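A short sketch of the `dim` argument on illustrative 3-D tensors follows. Negative values count from the last dimension, so `dim=-1` is the same as `dim=2` here. Leaving `dim` out uses the default of `0`.

```python
import torch

x = torch.zeros(2, 3, 4)
y = torch.zeros(2, 3, 5)

# Sizes may differ only along the chosen dim; here that is the last one.
z = torch.cat((x, y), dim=2)
z_neg = torch.cat((x, y), dim=-1)  # negative dims count from the end: -1 == 2 here
print(z.shape, z_neg.shape)  # torch.Size([2, 3, 9]) torch.Size([2, 3, 9])

# Omitting dim uses the default dim=0, so the leading sizes are added.
w = torch.cat((x, x))
print(w.shape)  # torch.Size([4, 3, 4])
```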
- Keyword Arguments
out (Tensor, optional)
- the output tensor.
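The sketch below writes the result into a preallocated `out` tensor. The buffer is assumed to already have the result's shape and dtype. If `out` has the wrong shape, PyTorch resizes it, and recent versions emit a warning when a non-empty `out` has to be resized.

```python
import torch

a = torch.ones(2, 3)
b = torch.zeros(1, 3)

# Preallocate a buffer with the result's shape (2 + 1 rows) and dtype.
buf = torch.empty(3, 3)
result = torch.cat((a, b), dim=0, out=buf)

print(torch.equal(buf, torch.cat((a, b), dim=0)))  # True
print(result.data_ptr() == buf.data_ptr())         # True: the result lives in buf's memory
```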
2. Example
```
(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 0)
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 1)
tensor([[ 0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818]])
>>>
>>> exit()
(base) yongqiang@yongqiang:~$
```
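The shapes in this session follow a simple rule. The output keeps every input dimension except `dim`, where the sizes are summed: three `(2, 3)` inputs give `(6, 3)` along dim 0 and `(2, 9)` along dim 1. The sketch below checks that rule with `cat_shape`, a hypothetical helper written only for illustration. It does not model the `(0,)` empty-tensor exception.

```python
import torch

def cat_shape(shapes, dim=0):
    """Hypothetical helper: predict torch.cat's output shape from input shapes.
    Ignores the special case of 1-D empty tensors of size (0,)."""
    ndim = len(shapes[0])
    dim = dim % ndim  # normalise negative dims
    out = list(shapes[0])
    for s in shapes[1:]:
        if len(s) != ndim or any(s[i] != out[i] for i in range(ndim) if i != dim):
            raise ValueError(f"shape {tuple(s)} does not match outside dim {dim}")
        out[dim] += s[dim]  # sizes add up along the concatenating dimension
    return tuple(out)

x = torch.randn(2, 3)
for d in (0, 1):
    predicted = cat_shape([x.shape] * 3, dim=d)
    actual = tuple(torch.cat((x, x, x), d).shape)
    print(d, predicted, actual)
# 0 (6, 3) (6, 3)
# 1 (2, 9) (2, 9)
```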
3. Example
https://github.com/karpathy/llama2.c/blob/master/model.py
```python
import torch

idxs = torch.randn(1, 5)
print("idxs.shape:", idxs.shape)
print("idxs:\n", idxs)

next_idx = torch.randn(1, 1)
print("\nnext_idx.shape:", next_idx.shape)
print("next_idx:\n", next_idx)

print("\nidxs.size(1):", idxs.size(1))

# Append next_idx along dim=1 (the sequence dimension): (1, 5) + (1, 1) -> (1, 6)
idxs_set = torch.cat((idxs, next_idx), dim=1)
print("\nidxs_set.shape:", idxs_set.shape)
print("idxs_set:\n", idxs_set)
```
```
/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py
idxs.shape: torch.Size([1, 5])
idxs:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691]])

next_idx.shape: torch.Size([1, 1])
next_idx:
 tensor([[0.4807]])

idxs.size(1): 5

idxs_set.shape: torch.Size([1, 6])
idxs_set:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691,  0.4807]])

Process finished with exit code 0
```
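The next example is a hedged sketch of where this `torch.cat(..., dim=1)` pattern typically appears: an autoregressive decoding loop like the one in llama2.c's `model.py`, where each new token id is appended to the running sequence. The model here is a hypothetical stand-in, `fake_logits`, that returns random logits. The next token is picked by greedy `argmax` for simplicity, which is not necessarily how llama2.c samples. Integer token ids are used instead of the random floats above, so both inputs to `torch.cat` are `int64`.

```python
import torch

torch.manual_seed(0)
vocab_size = 10                             # hypothetical toy vocabulary size
idx = torch.randint(0, vocab_size, (1, 5))  # (batch=1, seq_len=5) token ids, int64

def fake_logits(tokens):
    """Hypothetical stand-in for a model: random logits for every position."""
    return torch.randn(tokens.size(0), tokens.size(1), vocab_size)

for _ in range(3):
    logits = fake_logits(idx)[:, -1, :]                    # last position: (1, vocab_size)
    idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (1, 1), int64 like idx
    idx = torch.cat((idx, idx_next), dim=1)                # sequence grows by one along dim=1
    print(idx.shape)
# torch.Size([1, 6])
# torch.Size([1, 7])
# torch.Size([1, 8])
```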
References
[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/