如题所示,我从若干年前使用pytorch开始,就发现它同时有torch.tensor()
与torch.Tensor()
两个方法,都可以创建新tensor,然而我并不清楚二者有何区别,只知道torch.tensor()
用法常见些。今天又遇上这个问题,势必要把他搞清楚!
太长不看版
下面的长篇大论确实有点烦,在这里用简单几句话说一下二者的区别吧:
torch.tensor()
,以及我们常见的torch.zeros()
, torch.empty()
等,都是pytorch专门实现的一种所谓“工厂函数”的东西,它实际上封装了多个过程:
1. 对象(tensor)创建过程
2. 给创建出的tensor赋值,赋值时会参考你传入的数据dtype,将你的tensor转成特定的类型,如torch.long/torch.double(奶奶的,torch的api改的挺频繁,我说咋找不到当年的torch.LongTensor/torch.DoubleTensor了........)
3. 根据你给出的device信息,完成tensor的设备操作
正文开始
在编程中,**工厂函数(Factory Function)**是一种设计模式,它通过一个函数来创建并返回对象,而无需直接调用类的构造函数。这种模式的核心目的是__封装对象的创建过程__,提供更灵活、更安全的对象生成方式。
PyTorch 中的工厂函数
在 PyTorch 的上下文中,工厂函数特指那些**直接生成张量(Tensor
)**的函数。这些函数封装了张量的内存分配和初始化逻辑,用户无需手动处理底层细节(如内存布局、设备分配等)。
常见 PyTorch 工厂函数
函数 | 作用 | 示例 |
---|---|---|
torch.zeros() | 创建全零张量 | torch.zeros(3,3) |
torch.ones() | 创建全一张量 | torch.ones(2,2) |
torch.rand() | 创建均匀分布([0,1) )的随机张量 | torch.rand(4) |
torch.randn() | 创建标准正态分布(N(0,1) )张量 | torch.randn(5,5) |
torch.empty() | 创建未初始化张量(内容随机) | torch.empty(10) |
torch.tensor() | 从数据(如列表)直接创建张量 | torch.tensor([1,2,3]) |
torch.arange() | 创建等差序列张量 | torch.arange(0, 5, step=1) |
工厂函数 vs 直接构造函数
PyTorch 张量的底层构造函数是 torch.Tensor()
,但直接使用它可能不够灵活。工厂函数是对构造函数的增强封装,提供了更人性化的接口。
对比示例
# 方式1: 使用构造函数(不推荐)
x = torch.Tensor(3, 3) # 创建一个 3x3 张量,内容未初始化(可能是任意值)# 方式2: 使用工厂函数(推荐)
y = torch.zeros(3, 3) # 明确生成全零张量
z = torch.rand(3, 3) # 明确生成随机张量
工厂函数的优势
-
语义清晰:
torch.rand(3,3)
直接表明生成的是均匀分布的随机张量。- 构造函数
torch.Tensor(3,3)
仅创建未初始化的张量,行为不明确。
-
隐式类型推断:
- 工厂函数自动推断数据类型(如
float32
):a = torch.tensor([1, 2, 3]) # 自动推断为 int64 b = torch.tensor([1.0, 2.0, 3.0]) # 自动推断为 float32
- 构造函数可能需要手动指定类型:
c = torch.Tensor([1, 2, 3]) # 强制转换为 float32
- 工厂函数自动推断数据类型(如
-
设备一致性:
- 工厂函数通过
device
参数统一控制张量位置(CPU/GPU):d = torch.rand(3, 3, device="cuda") # 直接在 GPU 上生成张量
- 工厂函数通过
-
参数灵活性:
- 支持动态参数(如形状、分布参数):
shape = (2, 2) e = torch.normal(mean=0, std=1, size=shape) # 灵活生成正态分布张量
- 支持动态参数(如形状、分布参数):
工厂函数的设计哲学
-
封装复杂性:
- 隐藏张量底层的内存分配、类型转换、设备分配等细节。
- 例如:
torch.randn()
内部需要调用随机数生成器和内存分配器。
-
提供安全保证:
- 避免用户直接操作未初始化内存(如
torch.empty()
的随机内容可能引发意外行为)。
- 避免用户直接操作未初始化内存(如
-
接口统一性:
- 所有工厂函数遵循一致的命名规则(如
torch.xxx()
),降低学习成本。
- 所有工厂函数遵循一致的命名规则(如
工厂函数的底层实现
以 torch.zeros()
为例,其伪代码逻辑大致如下:
def zeros(*size, dtype=None, device=None):# 1. 分配内存tensor = torch.Tensor.__new__(Tensor) # 调用底层构造函数tensor.resize_(*size)# 2. 初始化数据tensor.fill_(0)# 3. 处理类型和设备if dtype is not None:tensor = tensor.to(dtype)if device is not None:tensor = tensor.to(device)return tensor
何时使用工厂函数?
- 默认选择:绝大多数情况下应优先使用工厂函数,因为它们更安全、更直观。
- 例外场景:
- 需要完全控制内存分配时(如预分配缓冲区),可使用
torch.empty()
+ 原地操作(如normal_()
)。 - 需要兼容旧代码或特殊数据类型时,可能需要直接调用构造函数。
- 需要完全控制内存分配时(如预分配缓冲区),可使用
总结
PyTorch 的工厂函数是用户友好的张量生成接口,它们:
- 封装底层复杂性,提供简洁的语义(如
zeros
、randn
)。 - 确保数据安全和类型正确性。
- 统一管理设备(CPU/GPU)和数据类型的隐式逻辑。
这种设计让开发者能更专注于算法逻辑,而非底层内存管理细节。