Jensen不等式实战解析(一)——从信息论到机器学习

📅 2026/6/20 3:08:04
Jensen不等式实战解析(一)——从信息论到机器学习
1. 初识Jensen不等式从凸函数到概率期望第一次听说Jensen不等式是在研究生时期的概率论课上。当时教授在黑板上画了一个凸函数的图像然后在曲线上方随意点了几个点用直线连接起来。这个简单的几何演示让我立刻理解了Jensen不等式的核心思想对于凸函数函数值的平均值总是大于等于平均值的函数值。用数学语言来说对于一个凸函数f如果λ₁ λ₂ ... λₙ 1且λᵢ ≥ 0那么有 f(∑λᵢxᵢ) ≤ ∑λᵢf(xᵢ)这个看似简单的式子却在信息论和机器学习中扮演着关键角色。举个生活中的例子假设f(x)代表咖啡的价格随温度变化的函数温度越高价格越贵且涨价幅度越来越大这就是凸函数。那么两杯不同温度咖啡的平均温度对应的价格会低于这两杯咖啡价格的算术平均。这就是Jensen不等式在现实中的直观体现。2. 信息论中的核心应用从熵到KL散度2.1 信息熵的凸性证明在信息论中Jensen不等式最经典的应用就是证明信息熵的凸性。香农熵H(X)-∑p(x)log p(x)实际上就是一个关于概率分布的凹函数因为负对数函数是凸的。利用Jensen不等式我们可以证明对于两个概率分布p和q以及0≤λ≤1有 H(λp (1-λ)q) ≥ λH(p) (1-λ)H(q)这个性质保证了信息熵具有良好的数学性质也是很多信息论结论的基础。我在研究数据压缩时就深刻体会到这个性质的重要性——它确保了混合概率分布的信息量不会突然增大。2.2 KL散度的非负性证明KL散度Kullback-Leibler divergence是衡量两个概率分布差异的重要指标。利用Jensen不等式和对数函数的凸性我们可以优雅地证明KL散度的非负性D(p||q) ∑p(x)log(p(x)/q(x)) -∑p(x)log(q(x)/p(x)) ≥ -log(∑p(x)(q(x)/p(x))) -log(1) 0这个证明过程中关键一步就是用到了对数函数的凸性和Jensen不等式。在实际项目中评估模型预测分布与真实分布的差异时KL散度的这个性质保证了我们的评估指标总是有意义的。3. 机器学习中的关键桥梁EM算法解析3.1 EM算法中的下界构造EMExpectation-Maximization算法是机器学习中处理隐变量模型的经典方法。我第一次实现EM算法时就被其中Jensen不等式的巧妙应用所震撼。在E步我们需要构造一个对数似然函数的下界log p(X|θ) log ∑ p(X,Z|θ) log ∑ q(Z) [p(X,Z|θ)/q(Z)] ≥ ∑ q(Z) log [p(X,Z|θ)/q(Z)]这里的关键就是把log函数凹函数放在求和符号外面利用Jensen不等式得到下界。这个下界通常更容易优化从而引出了M步的参数更新。3.2 变分推断中的变分下界在更复杂的概率图模型中变分推断Variational Inference同样依赖于Jensen不等式来构造证据下界ELBOlog p(X) ≥ E_q[log p(X,Z)] - E_q[log q(Z)]这个下界使得我们可以用简单的分布q来近似复杂的后验分布p(Z|X)。在实际项目中我经常用这个技巧来处理高维隐变量模型大大简化了计算复杂度。4. 优化问题的实用技巧从理论到实现4.1 损失函数设计中的凸性保证在设计机器学习模型的损失函数时凸性是一个非常重要的性质。利用Jensen不等式我们可以验证很多常用损失函数的凸性。例如对于逻辑回归的负对数似然损失L(θ) -∑ [y_i log σ(θ^T x_i) (1-y_i)log(1-σ(θ^T x_i))]其中σ是sigmoid函数。由于sigmoid函数的对数凹性结合Jensen不等式可以证明这个损失函数是凸的从而保证梯度下降能找到全局最优解。4.2 正则化项的推导在贝叶斯视角下正则化项通常对应着参数的先验分布。比如L2正则化对应高斯先验L1正则化对应拉普拉斯先验。利用Jensen不等式我们可以推导出这些正则化项在优化过程中的行为边界。例如在推导变分自编码器VAE的目标函数时重构误差和KL散度项的平衡就依赖于Jensen不等式提供的理论保证。这让我在实际调参时能够更好地理解每个超参数的作用。5. 实战案例分析从理论到代码实现5.1 用Python验证Jensen不等式让我们用Python实际验证一下Jensen不等式。以指数函数为例import numpy as np # 定义凸函数 def f(x): return np.exp(x) # 随机生成点和权重 x np.random.rand(5) lambda_ np.random.dirichlet(np.ones(5)) # 计算两边值 left f(np.sum(lambda_ * x)) right np.sum(lambda_ * f(x)) print(ff(∑λx): {left:.4f} ≤ ∑λf(x): {right:.4f})运行结果通常会显示左边小于等于右边验证了Jensen不等式。不过要注意当点数很少时由于浮点精度可能会出现看似违反不等式的情况这是数值计算中的常见陷阱。5.2 在PyTorch中实现EM算法下面是一个简化的EM算法实现展示了如何利用Jensen不等式import torch def em_algorithm(data, n_components, n_iter100): # 初始化参数 mu torch.randn(n_components) var torch.ones(n_components) pi torch.ones(n_components)/n_components for _ in range(n_iter): # E步计算后验概率利用Jensen不等式的下界 log_prob -0.5 * ((data[:,None]-mu)**2/var torch.log(var)) log_weighted log_prob torch.log(pi) q torch.softmax(log_weighted, dim1) # M步最大化下界 Nk q.sum(0) pi Nk / len(data) mu (q.T data) / Nk var (q.T (data[:,None]-mu)**2) / Nk return mu, var, pi这个实现展示了如何将理论转化为实际代码。E步中softmax的计算实际上就是在构造Jensen不等式中的下界。6. 常见误区与调试技巧6.1 函数凸性判断错误新手最容易犯的错误就是错误判断函数的凸性。我曾经在一个项目中误以为某个复合函数是凸的导致推导的算法不收敛。后来通过绘制函数图像和二阶导数检查才发现问题。建议在使用Jensen不等式前先用以下方法验证凸性计算二阶导数对于可微函数绘制函数曲线观察用随机点验证不等式是否成立6.2 权重条件的忽视Jensen不等式要求权重λᵢ满足∑λᵢ1且λᵢ≥0。在实际应用中特别是自己设计新算法时很容易忽略这个条件。我曾在实现一个变分推断算法时因为没有正确归一化权重导致结果完全错误。调试这类问题时建议添加assert语句检查权重和使用softmax等保证归一化在文档中明确标注权重约束6.3 数值稳定性问题对数域计算时容易出现数值不稳定问题。例如在计算log-sum-exp时直接实现可能会导致上溢或下溢。解决方案是使用以下稳定实现def logsumexp(x): x_max x.max() return x_max torch.log(torch.sum(torch.exp(x - x_max)))这个技巧在实现EM算法和变分推断时特别有用可以避免很多难以调试的数值问题。