您的位置:首页 > 科技 > 能源 > 网站搭建徐州百度网络_电商要怎么做起来_优秀营销案例分享_seo免费优化软件

网站搭建徐州百度网络_电商要怎么做起来_优秀营销案例分享_seo免费优化软件

2024/10/31 21:24:50 来源:https://blog.csdn.net/qq_27390023/article/details/143373141  浏览:    关键词:网站搭建徐州百度网络_电商要怎么做起来_优秀营销案例分享_seo免费优化软件
网站搭建徐州百度网络_电商要怎么做起来_优秀营销案例分享_seo免费优化软件

PyTorch 的 torch.distributions 模块提供了对概率分布的全面支持,允许用户通过对象化的方式定义、操作和采样各种常见分布。该模块适用于概率建模、生成模型(如变分自动编码器 VAE)、强化学习等需要使用分布的场景。每种分布都有通用的接口来计算概率、对数概率、采样等。

常用的分布类型

以下是 torch.distributions 模块中一些常见的分布:

Categorical
  • 离散分类分布,可以用于从多类别中采样。
  • 用法:输入类别的概率 probs 或对数概率 logits
from torch.distributions import Categorical
probs = torch.tensor([0.2, 0.5, 0.3])
dist = Categorical(probs)
sample = dist.sample()  # 从 [0, 1, 2] 中采样
Normal
  • 正态分布(高斯分布),常用于生成连续值或噪声。
  • 用法:需要指定均值 mean 和标准差 std
from torch.distributions import Normal
dist = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
sample = dist.sample((5,))  # 从正态分布中采样 5 个样本
Bernoulli
  • 伯努利分布,用于二项采样。
  • 用法:输入事件发生的概率 probs
from torch.distributions import Bernoulli
dist = Bernoulli(torch.tensor([0.7]))
sample = dist.sample()  # 生成 0 或 1 的二值样本
Poisson
  • 泊松分布,常用于建模单位时间内事件发生的次数。
  • 用法:给定事件发生的速率参数 rate
from torch.distributions import Poisson
dist = Poisson(torch.tensor([3.0]))  # λ = 3.0
sample = dist.sample((5,))  # 生成 5 个泊松分布样本
Exponential
  • 指数分布,通常用于建模事件间隔的时间。
  • 用法:给定 rate 参数,事件发生的速率。
from torch.distributions import Exponential
dist = Exponential(torch.tensor([1.0]))
sample = dist.sample((5,))  # 从指数分布中采样 5 个样本
Beta
  • Beta 分布,适用于生成 0 到 1 之间的连续值,常用于贝叶斯建模。
  • 用法:需要指定两个形状参数 alpha 和 beta
from torch.distributions import Beta
dist = Beta(torch.tensor([2.0]), torch.tensor([5.0]))
sample = dist.sample((5,))  # 采样 5 个 Beta 分布样本
Gamma
  • Gamma 分布,用于生成正的连续值,广泛应用于队列理论、贝叶斯统计。
  • 用法:需要指定 concentration 和 rate 参数。
from torch.distributions import Gamma
dist = Gamma(torch.tensor([1.0]), torch.tensor([2.0]))
sample = dist.sample((5,))  # 从 Gamma 分布中采样 5 个样本
MultivariateNormal
  • 多变量正态分布,用于生成多维连续值,适合应用在多维高斯模型或联合概率建模中。
  • 用法:需要均值向量 mean 和协方差矩阵 covariance_matrix
from torch.distributions import MultivariateNormal
mean = torch.zeros(2)
cov_matrix = torch.eye(2)
dist = MultivariateNormal(mean, cov_matrix)
sample = dist.sample()  # 生成 2 维样本

分布对象的通用方法

每个分布对象一般都支持以下几个通用方法:

sample:从分布中采样。
sample = dist.sample((5,))  # 采样 5 个样本
log_prob:计算样本的对数概率密度(或对数概率质量)。
log_prob = dist.log_prob(sample)  # 计算给定样本的对数概率
entropy:计算分布的熵。
entropy = dist.entropy()  # 得到分布的熵值
mean 和 variance:可以直接访问分布的均值和方差。
mean = dist.mean
variance = dist.variance
rsample:用于重新采样,支持自动微分(通常用于生成模型中的重参数化技巧)。
rsample = dist.rsample()  # 用于支持重参数化的采样

torch.distributions 的典型应用

  • 生成模型:如变分自动编码器(VAE),通过从正态分布中采样来生成潜在空间的点。
  • 概率推断:如马尔可夫链蒙特卡洛(MCMC)和强化学习中基于策略的抽样。
  • 贝叶斯建模:如 Beta 分布和 Gamma 分布在贝叶斯统计中的应用。

torch.distributions 模块提供了灵活的 API,支持定义复杂的概率模型和采样过程。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com