您的位置:首页 > 娱乐 > 八卦 > PyTorch 基础学习(3) - 张量的数学操作

PyTorch 基础学习(3) - 张量的数学操作

2024/12/22 0:10:34 来源:https://blog.csdn.net/fenglingguitar/article/details/141139808  浏览:    关键词:PyTorch 基础学习(3) - 张量的数学操作

下面是关于PyTorch中常见数学操作的概述和教程,包括逐点运算、比较操作、线性代数操作等,突出每个操作的重点用法和示例。

逐点操作 (Pointwise Operations)

1. torch.abs
  • 功能: 计算输入张量的每个元素的绝对值。
  • 用法: torch.abs(input)
  • 示例:
    import torch
    tensor = torch.FloatTensor([-1, -2, 3])
    result = torch.abs(tensor)
    print(result)  # 输出: tensor([1., 2., 3.])
    
2. torch.acos
  • 功能: 返回输入张量每个元素的反余弦。
  • 用法: torch.acos(input)
  • 示例:
    a = torch.tensor([0.5, -1.0, 1.0])
    result = torch.acos(a)
    print(result)  # 输出: tensor([1.0472, 3.1416, 0.0000])
    
3. torch.add
  • 功能: 将一个标量值逐元素加到输入张量上。
  • 用法: torch.add(input, value)
  • 示例:
    a = torch.randn(4)
    result = torch.add(a, 10)
    print(result)
    
4. torch.mul
  • 功能: 用标量值乘以输入张量的每个元素。
  • 用法: torch.mul(input, value)
  • 示例:
    a = torch.randn(3)
    result = torch.mul(a, 100)
    print(result)
    

线性代数操作 (Linear Algebra Operations)

1. torch.mm
  • 功能: 矩阵乘法。
  • 用法: torch.mm(mat1, mat2)
  • 示例:
    mat1 = torch.randn(2, 3)
    mat2 = torch.randn(3, 3)
    result = torch.mm(mat1, mat2)
    print(result)
    
2. torch.inverse
  • 功能: 计算方阵的逆矩阵。
  • 用法: torch.inverse(input)
  • 示例:
    x = torch.rand(3, 3)
    result = torch.inverse(x)
    print(result)
    
3. torch.svd
  • 功能: 奇异值分解。
  • 用法: torch.svd(input)
  • 示例:
    a = torch.randn(5, 3)
    u, s, v = torch.svd(a)
    print(u, s, v)
    

比较操作 (Comparison Operations)

1. torch.eq
  • 功能: 比较两个张量的元素是否相等。
  • 用法: torch.eq(input, other)
  • 示例:
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([1, 1, 4])
    result = torch.eq(a, b)
    print(result)  # 输出: tensor([ True, False, False])
    
2. torch.gt
  • 功能: 比较两个张量的元素是否大于另一个。
  • 用法: torch.gt(input, other)
  • 示例:
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([1, 1, 4])
    result = torch.gt(a, b)
    print(result)  # 输出: tensor([False,  True, False])
    
3. torch.max
  • 功能: 返回输入张量所有元素的最大值。
  • 用法: torch.max(input)
  • 示例:
    a = torch.randn(1, 3)
    result = torch.max(a)
    print(result)
    

下面是一个包含逐点操作、线性代数操作和比较操作的PyTorch综合应用实例。该示例展示了如何处理一个简单的线性回归任务。

应用实例:线性回归

在这个例子中,我们将使用PyTorch来进行一个简单的线性回归任务。我们将生成一些合成数据,使用线性回归模型进行拟合,并展示模型参数的更新过程。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 生成合成数据
torch.manual_seed(0)  # 固定随机种子
X = torch.linspace(0, 10, 100).view(-1, 1)  # 输入特征
y = 2 * X + 1 + torch.randn(X.size()) * 2  # 线性关系加噪声# 定义线性回归模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)  # 输入和输出都是1维def forward(self, x):return self.linear(x)# 初始化模型、损失函数和优化器
model = LinearRegressionModel()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器# 训练模型
num_epochs = 100
for epoch in range(num_epochs):# 前向传播outputs = model(X)loss = criterion(outputs, y)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 每10个epoch输出一次损失if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 获取训练后的参数
with torch.no_grad():slope = model.linear.weight.item()intercept = model.linear.bias.item()print(f'Learned parameters: Slope={slope:.4f}, Intercept={intercept:.4f}')# 可视化结果
predicted = model(X).detach().numpy()
plt.plot(X.numpy(), y.numpy(), 'ro', label='Original data')
plt.plot(X.numpy(), predicted, label='Fitted line')
plt.legend()
plt.show()

输出:
在这里插入图片描述

说明

  1. 数据生成: 我们生成了一组带有噪声的线性数据,线性关系为 (y = 2x + 1)。

  2. 模型定义: 使用nn.Linear定义了一个简单的线性回归模型。

  3. 损失函数和优化器: 使用均方误差作为损失函数,随机梯度下降作为优化器。

  4. 训练过程: 在训练过程中,我们执行前向传播计算损失,然后通过反向传播更新模型参数。

  5. 结果展示: 训练完成后,我们打印出学习到的参数,并使用matplotlib绘制原始数据和拟合直线。

总结

  • 逐点操作用于对张量中的每个元素进行相同的数学运算。
  • 线性代数操作涉及矩阵运算,如矩阵乘法、逆矩阵和分解。
  • 比较操作用于比较张量元素之间的关系。
  • PyTorch中的这些操作通常具有广播机制,允许对不同形状的张量进行运算,只要它们的形状是兼容的。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com