在 PyTorch 中,fill_
是一个张量的原地操作方法,用于将张量中的所有元素填充为指定的值。
方法签名
Tensor.fill_(value)
参数
value
(float or int): 要填充到张量中的值。
返回值
返回调用该方法的张量本身,且是经过修改后的张量。
特点
- 原地操作:
fill_
是原地操作,会直接修改调用它的张量的内容,而不会创建新的张量。 - 广播不可用: 它直接填充整个张量的所有元素,不支持像其他操作那样进行广播。
- 效率高: 由于是原地操作,减少了内存分配和复制,效率较高。
示例代码
基本用法
import torch# 创建一个张量
tensor = torch.zeros(3, 3)
print("填充前:")
print(tensor)# 使用 fill_ 填充
tensor.fill_(5)
print("填充后:")
print(tensor)
修改部分维度的张量
可以通过索引或切片选中部分张量,然后对选中的子张量调用 fill_
:
tensor = torch.zeros(4, 4)# 修改张量的部分内容
tensor[1:3, 1:3].fill_(9)
print(tensor)
使用整型填充
tensor = torch.ones(3, 3)# 填充整数值
tensor.fill_(42)
print(tensor)
注意事项
- 原地操作: 由于是原地操作,调用此方法后,原张量的内容会被永久更改。
- 类型一致性: 填充值的类型需要与张量的数据类型兼容,例如,填充一个
float
类型的张量时不能直接传入字符串等无关类型。
先克隆张量再填充: clone
+ fill_
new_tensor = tensor.clone().fill_(10)