PyTorch 切片运算 {Slice Operator}
- 1. `[:, -1, :]`
- 2. `[:, [-1], :]`
- References
1. `[:, -1, :]`
https://github.com/karpathy/llama2.c/blob/master/model.py
import torch

# Demo: selecting the last position along dim 1 with a plain integer index.
# An integer index (negative or positive) removes the indexed dimension.

# 1-D tensor holding 1..15 (arange's end bound is exclusive).
logits = torch.arange(1, 16)
print("logits.shape:", logits.shape)
print("logits:\n", logits)

# Reshape to (1, 3, 5): one batch, three rows, five values per row.
logits = logits.view(1, 3, 5)
print("\nlogits.shape:", logits.shape)
print("logits:\n", logits)

# -1 selects the last row along dim 1; dim 1 is dropped -> shape (1, 5).
final_logit_1 = logits[:, -1, :]
print("\nfinal_logit_1.shape:", final_logit_1.shape)
print("final_logit_1:\n", final_logit_1)

# A trailing ':' is implicit, so this is the same selection as above.
final_logit_2 = logits[:, -1]
print("\nfinal_logit_2.shape:", final_logit_2.shape)
print("final_logit_2:\n", final_logit_2)

# With 3 rows, positive index 2 is the same row as -1.
final_logit_3 = logits[:, 2, :]
print("\nfinal_logit_3.shape:", final_logit_3.shape)
print("final_logit_3:\n", final_logit_3)

# Same as final_logit_3 with the trailing ':' omitted.
final_logit_4 = logits[:, 2]
print("\nfinal_logit_4.shape:", final_logit_4.shape)
print("final_logit_4:\n", final_logit_4)
/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py
logits.shape: torch.Size([15])
logits:
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

logits.shape: torch.Size([1, 3, 5])
logits:
 tensor([[[ 1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10],
         [11, 12, 13, 14, 15]]])

final_logit_1.shape: torch.Size([1, 5])
final_logit_1:
 tensor([[11, 12, 13, 14, 15]])

final_logit_2.shape: torch.Size([1, 5])
final_logit_2:
 tensor([[11, 12, 13, 14, 15]])

final_logit_3.shape: torch.Size([1, 5])
final_logit_3:
 tensor([[11, 12, 13, 14, 15]])

final_logit_4.shape: torch.Size([1, 5])
final_logit_4:
 tensor([[11, 12, 13, 14, 15]])

Process finished with exit code 0
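2. `[:, [-1], :]`
https://github.com/karpathy/llama2.c/blob/master/model.py

Indexing dim 1 with the integer `-1` removes that dimension (shape `[1, 5]`). Indexing it with the list `[-1]` keeps the dimension with size 1 (shape `[1, 1, 5]`), as the output below shows.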
import torch

# Demo: integer index vs. single-element list index along dim 1.
# An integer index drops the dimension; a list index keeps it (size 1).

# 1-D tensor holding 1..15 (arange's end bound is exclusive).
logits = torch.arange(1, 16)
print("logits.shape:", logits.shape)
print("logits:\n", logits)

# Reshape to (1, 3, 5): one batch, three rows, five values per row.
logits = logits.view(1, 3, 5)
print("\nlogits.shape:", logits.shape)
print("logits:\n", logits)

# Integer index: last row along dim 1, dim 1 removed -> shape (1, 5).
final_logit_1 = logits[:, -1, :]
print("\nfinal_logit_1.shape:", final_logit_1.shape)
print("final_logit_1:\n", final_logit_1)

# List index [-1]: same row, but dim 1 is kept with size 1 -> shape (1, 1, 5).
final_logit_2 = logits[:, [-1], :]
print("\nfinal_logit_2.shape:", final_logit_2.shape)
print("final_logit_2:\n", final_logit_2)

# Trailing ':' omitted; identical to final_logit_1.
final_logit_3 = logits[:, -1]
print("\nfinal_logit_3.shape:", final_logit_3.shape)
print("final_logit_3:\n", final_logit_3)

# Trailing ':' omitted; identical to final_logit_2.
final_logit_4 = logits[:, [-1]]
print("\nfinal_logit_4.shape:", final_logit_4.shape)
print("final_logit_4:\n", final_logit_4)
/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py
logits.shape: torch.Size([15])
logits:
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

logits.shape: torch.Size([1, 3, 5])
logits:
 tensor([[[ 1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10],
         [11, 12, 13, 14, 15]]])

final_logit_1.shape: torch.Size([1, 5])
final_logit_1:
 tensor([[11, 12, 13, 14, 15]])

final_logit_2.shape: torch.Size([1, 1, 5])
final_logit_2:
 tensor([[[11, 12, 13, 14, 15]]])

final_logit_3.shape: torch.Size([1, 5])
final_logit_3:
 tensor([[11, 12, 13, 14, 15]])

final_logit_4.shape: torch.Size([1, 1, 5])
final_logit_4:
 tensor([[[11, 12, 13, 14, 15]]])

Process finished with exit code 0
References
[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/