当前位置: 首页> 教育> 高考 > 2021重大军事新闻_初中生怎么升大专学历_百度竞价调价软件_天津百度优化

2021重大军事新闻_初中生怎么升大专学历_百度竞价调价软件_天津百度优化

时间:2025/7/19 10:15:11来源:https://blog.csdn.net/qq_27390023/article/details/146187529 浏览次数:0次
2021重大军事新闻_初中生怎么升大专学历_百度竞价调价软件_天津百度优化

torch.distributions.categorical.Categorical 是 PyTorch 提供的离散概率分布(Categorical Distribution)类,用于从类别型概率分布(Categorical Distribution)中采样随机变量。

1. 语法

torch.distributions.categorical.Categorical(probs=None, logits=None)

2. 参数

参数作用
probs概率分布,形状为 [batch_size, num_classes],其中每行的值应为非负数,且每行的总和为 1.0
logits也可以用 logits 代替 probs,即未归一化的分数(softmax 之前的值),PyTorch 会自动计算 softmax 归一化

⚠ 注意probs 和 logits 只能二选一,否则会报错。

3. 基本用法

(1)用 probs 采样
import torch# 定义类别概率(3 个类别)
probs = torch.tensor([0.1, 0.3, 0.6])  # 类别 0, 1, 2 的概率分别是 10%, 30%, 60%# 创建 Categorical 分布
dist = torch.distributions.categorical.Categorical(probs)# 采样一个类别
sample = dist.sample()
print(sample)  # 输出可能是 0, 1, 或 2,概率分别为 10%, 30%, 60%
(2)用 logits 采样
logits = torch.tensor([1.0, 2.0, 3.0])  # 未归一化的 logits
dist = torch.distributions.categorical.Categorical(logits=logits)sample = dist.sample()
print(sample)  # 输出 0, 1, 2 的概率由 softmax(logits) 决定

内部计算方式

probs = torch.nn.functional.softmax(logits, dim=-1)

所以,logits [1.0, 2.0, 3.0] 会被转换为:

probs = torch.tensor([0.0900, 0.2447, 0.6652])  # softmax 归一化后的概率

4. 批量采样

如果 probs 是一个二维 Tensor,则可以对多个分布进行批量采样:

probs = torch.tensor([[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]])  # 3 组分布
dist = torch.distributions.categorical.Categorical(probs)samples = dist.sample()
print(samples)  # 每个样本是 0 或 1(按不同的行概率分布)

解释

  • probs[0] = [0.2, 0.8],第 1 个分布中 1 的概率是 80%
  • probs[1] = [0.5, 0.5],第 2 个分布是均匀分布
  • probs[2] = [0.9, 0.1],第 3 个分布中 0 的概率是 90%

5. 计算 log_prob(计算样本的对数概率)

可以计算某个类别出现的 对数概率

probs = torch.tensor([0.1, 0.3, 0.6])
dist = torch.distributions.categorical.Categorical(probs)log_prob = dist.log_prob(torch.tensor(2))  # 计算类别 2 的对数概率
print(log_prob)  # 输出: -0.5108 (即 log(0.6))

log_prob 的作用:

  • log_prob(x) 计算 类别 x 出现的 log 概率,即 log(P(x))
  • 常用于计算损失函数(如交叉熵)

总结

torch.distributions.categorical.Categorical 是 PyTorch 中用于处理离散分类分布的工具。它支持从分布中采样、计算对数概率和熵,并且可以处理多维输入。在自然语言处理、掩码语言模型和强化学习等任务中,分类分布是一个非常重要的工具。

关键字:2021重大军事新闻_初中生怎么升大专学历_百度竞价调价软件_天津百度优化

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: