softplus的逆的数值稳定计算方法
- 什么是softplus函数?
- softplus函数的逆
- 理想表达式
- 稳定替代式
- 混合策略函数(自动数值稳定选择)
参考内容: Add softplus inverse #72759
什么是softplus函数?
根据Softplus的pytorch实现1 ,Softplus是ReLU函数2的近似,用于保证输出总是为正。其公式和图像如下:
Softplus ( x ) = 1 β ∗ log ( 1 + exp ( β ∗ x ) ) \operatorname{Softplus}(x)=\frac{1}{\beta} * \log (1+\exp (\beta * x)) Softplus(x)=β1∗log(1+exp(β∗x))
softplus函数的逆
我们很容易知道,对于 β = 1 \beta=1 β=1的softplus函数逆是:
y = log ( e x − 1 ) y=\log (e^x-1) y=log(ex−1)
我们希望实现一个softplus函数逆的数值稳定版本。记softplus为函数 f f f,softplus函数的逆为 f − 1 f^{-1} f−1.则对于任意输入,应该有 f − 1 f ( x ) = x f^{-1}f(x)=x f−1f(x)=x.
下面设计几种计算方式:
理想表达式
利用泰勒展开式:避免先算 exp(x) 再减 1,而是在内部使用了泰勒展开(或其他高精度公式) 来专门针对小 x 优化计算方式。
# 使用利用泰勒展开的expm1计算
def f1(t):return torch.expm1(t).log()x = torch.linspace(-50, 150, 200, requires_grad=True)
for i, f in enumerate([f1]):y = f(softplus(x))print(grad(y.sum(), [x])[0].sum())plt.plot(x.detach(), y.detach() + i * 2)
plt.show()
当 x 很小时,expm1(x)
会比 exp(x) - 1
更稳定。
缺点: 在代码测试下: 函数支持(x为-50到88)的范围。
当x很大时,t=softplus(x)也会很大,在经过log之前的结果,还是会有溢出,造成结果为inf。即这个计算方式不支持特别大的x。
稳定替代式
利用变形: log ( e x − 1 ) = x + log ( 1 − e − x ) \log \left(e^x-1\right)=x+\log \left(1-e^{-x}\right) log(ex−1)=x+log(1−e−x)
def f2(x):return x + (1 - x.neg().exp()).log()x = torch.linspace(-50, 150, 200, requires_grad=True)
for i, f in enumerate([f1, f2]):y = f(softplus(x))print(grad(y.sum(), [x])[0].sum())plt.plot(x.detach(), y.detach() + i * 2)
plt.show()
当 x 很大时,e^{-x} 很小,接近0,1 - e^{-x} 接近 1,不会有溢出,所以 (1 - e^{-x}) 是一个稳定的小差值,避免了 expm1 计算的缺点。
缺点: 在代码测试下: 函数支持(x为-16到150)的范围。
当x很小时,e^{-x} 很大,还是会有溢出。即这个计算方式不支持特别小的x。
混合策略函数(自动数值稳定选择)
def f3(x):big = x > torch.tensor(torch.finfo(x.dtype).max).log()return torch.where(big,f2(x.masked_fill(~big, 1.)),f1(x.masked_fill(big, 1.)),)
x = torch.linspace(-50, 150, 200, requires_grad=True)
for i, f in enumerate([f3]):y = f(softplus(x))print(y)print(grad(y.sum(), [x])[0].sum())plt.plot(x.detach(), y.detach() + i * 2)
plt.show()
- 判断 x 是否“太大”( e x e^x ex超过最大值),以至于 expm1(x) 会溢出。
- 如果太大了(big == True),就用 f2 来计算;
- 否则用f1计算。
torch.expm13是通过计算泰勒展开式而不是直接计算相减(消除了大多数有效位数)来提高计算精度,保证数值稳定性。
torch.finfo(x.dtype).max的目的是获取张量 x x x 所使用的数据类型(比如 float32 或 float16)能表示的最大值。
masked_fill的作用是
这里如果去掉masked_fill, f2(x)和f1(x)会出现一部分inf的值,仍然可以通过torch.where筛选出正确的值return。
但注意,虽然 torch.where() 只会输出 f1(x) 或 f2(x) 的部分结果,但 两个分支的计算都会执行,并且 参与反向传播。
如果在 f1(x) 或 f2(x) 中出现了无效操作(比如 log(0)、log(-inf)、log(nan)),哪怕这些值最终没被 where() 选中,梯度也会传回去,导致整个图崩溃或者出 nan。
所以masked_fill在这里的作用是: 用安全值填充未使用的区域,避免溢出或非法操作
Softplus的pytorch实现 ↩︎
ReLU激活函数:简单之美 ↩︎
torch.expm1作用 ↩︎