构建具有多轮迭代偏好学习的数学智能体–《BUILDING MATH AGENTS WITH MULTI-TURN ITERA- TIVE PREFERENCE LEARNING》全文阅读
第一作者:Wei Xiong
第一单位:University of Illinois Urbana-Champaign
摘要
最近的研究表明,通过集成外部工具(例如代码解释器)和采用多轮思维链(CoT)推理,可以增强大型语言模型(LLMs)的数学问题解决能力。虽然现有方法侧重于合成数据生成和监督微调(SFT),但本文研究了互补的直接偏好学习方法以进一步提高模型性能。然而,现有的直接偏好学习算法最初是为单轮聊天任务设计的,并未完全解决工具集成数学推理任务所需的多轮推理和外部工具集成的复杂性。为了弥补这一空白,我们引入了一个专为此环境量身定制的多轮直接偏好学习框架,该框架利用代码解释器的反馈并优化轨迹级偏好。该框架包括多轮 DPO 和多轮 KTO 作为具体实现。通过使用 GSM8K 和 MATH 数据集的增强提示集训练各种语言模型,验证了我们框架的有效性。我们的结果显示了实质性的改进:一个经过监督微调的 Gemma-1.1-it-7B 模型在 GSMB8K 上的性能从 77.5% 提高到 83.9%,在 MATH 上从 46.1% 提高到 51.2%。同样,一个 Gemma-2-it-9B 模型在 GSMB8kK 上的性能从 84.1% 提高到 86.3%,在 MATH 上从 51.0% 提高到 54.5%。
1 INTRODUCTION
大型语言模型(LLMs)在各种语言任务中展示了卓越的能力。著名的模型包括 ChatGPT (OpenAl] |2023)、Claude (Anthropic} |2023) 和 Gemini 2023)。然而,尽管取得了这些进展,即使是最先进的闭源 LLMs 仍然在需要多轮决策的复杂推理任务上苦苦挣扎。特别是对于数学问题解决这一代表性任务,LLMs 经常在基本算术和符号计算方面失败 (Hendrycks et al.| 2021)。为了解决这个问题,最近的研究建议集成外部工具(例如计算器、计算 Python 库和符号求解器)来增强 LLMs 的数学问题解决能力 [2022} [2022| [2024a)。具体而言,通过将自然语言推理与这些外部工具的使用相结合,这些增强的 LLMs 可以接收来自工具交互的外部消息,并根据先前生成的 token 和外部消息进行推理,这显著提高了它们在数学任务中的表现 2023b} Toshniwal et al.| 2024} Shao et al.| 2024)。这些合成数据集在 MATH (Hendrycks et al.|/2021) 和 GSM8K (Cobbe et al.||202 1a) 等标准基准测试上取得了显著的测试准确率提高。在强大的 SFT 模型基础上,强化学习从人类反馈中学习(RLHF)已被证明是 LLMs 在后训练阶段引发知识的关键技术,并已成为 LLM 训练流水线中的标准 (Ouyang et al.| (2023)。广义地说,RLHF 学习范式最初是为了将 LLMs 与人类价值观和偏好对齐而设计的,它与 SFT 的区别在于它从相对反馈中学习。它显著增强了 ChatGPT、Claude 和 Gemini 等模型的能力,使它们能够生成更有帮助、无害和诚实的回复 [2022)。受 RLHF 在一般聊天应用中成功的启发,在本文中,我们探索了 RLHF 在装备外部工具时,用于提高 LLMs 数学问题解决能力的方法。特别是,由于深度 RL 方法(例如,近端策略优化,PPO 算法 (Schulman et al.||2017)) 通常样本效率低下且不稳定 (Choshen et al.| ,我们的目标是推导出直接偏好学习算法,该算法直接从偏好数据集中学习 (Zhao et al.|/2023}|Rafailov et al.|/2023)。
贡献:
我们首先将学习过程公式化为马尔可夫决策过程(MDP),这与 RLHF 中用于构建不与外部环境交互的通用聊天机器人的上下文匪徒方法通常不同 [2024} [Rafailov et al.||2023)。然后,我们推导出使用此类 MDP 进行规划的最优性条件,我们的发现表明,当外部随机性较低时,我们可以开发多轮直接对齐算法(M-DPO 和 M-KTO),其主要修改是在训练过程中屏蔽掉不相关的标记。此外,我们将我们的方法扩展到其在线迭代变体,最近的研究表明这是有希望的 (Xiong et al.| [2024] 2024b)。最后,我们通过案例研究评估了我们的方法,这些案例研究使用来自 MATH 和 GSM8K 基准的增强训练集,并采用各种基础模型,例如 Gemma (Team et al.} 2024)、CodeGemma (Team) P02} 和 Mistral (Ji {2023)。例如,监督微调的 Gemma-1.1-it-7B 模型在 GSM8K 上的性能从 77.5% 提高到 83.9%,在 MATH 上从 46.1% 提高到 51.2%。同样,Gemma-2-it-9B 模型在 GSM8K 上的性能从 84.1% 提高到 86.3%,在 MATH 上从 51.0% 提高到 54.5%。这些实证结果表明,与标准 SFT 模型相比,性能有了显著提高,证明了 RLHF 在复杂推理任务中的潜力。我们还提供了实践实现我们的在线迭代多轮方法的全面方案,并将公开我们的模型、数据集和代码,以供进一步研究和开发。
2 ALGORITHMS DEVELOPMENT
2.1 PROBLEM FORMULATION
我们首先正式阐述工具集成推理任务。在第一步,从某个分布 d 0 d_0 d0 中抽取提示 x ∈ X x \in \mathcal{X} x∈X 作为初始状态 s 1 = x s_1 = x s1=x。然后,在每个步骤 h ∈ [ H ] h \in [H] h∈[H] 中,
- 行动:智能体观察当前状态 s h s_h sh,它是与外部环境的前 h − 1 h-1 h−1 次交互的历史,并根据某个策略 π h ( ⋅ ∣ s h ) ∈ A ( s h ) \pi_h(\cdot|s_h) \in \mathcal{A}(s_h) πh(⋅∣sh)∈A(sh) 执行行动 a h a_h ah。
- 观察:为了响应智能体的行动, 环境 环境 环境 然后返回观察 o h ∼ P h ( ⋅ ∣ s h , a h ) o_h \sim P_h(\cdot|s_h, a_h) oh∼Ph(⋅∣sh,ah),基于历史 s h s_h sh 和当前行动 a h a_h ah。
然后,我们转换到新状态,这是到步骤 h + 1 h+1 h+1 的历史: s h + 1 = ( s h , a h , o h ) = ( x , a 1 , o 1 , ⋯ , a h , o h ) s_{h+1} = (s_h, a_h, o_h) = (x, a_1, o_1, \cdots, a_h, o_h) sh+1=(sh,ah,oh)=(x,a1,o1,⋯,ah,oh),并开始新步骤。此过程总共重复 H H H 轮,最终我们收集到一个轨迹: τ = ( x , a 1 , o 1 , ⋯ , a H , o H ) \tau = (x, a_1, o_1, \cdots, a_H, o_H) τ=(x,a1,o1,⋯,aH,oH)。我们在图 τ \tau τ 中展示了一个多轮工具集成推理的示例。通常,行动以 ReAct 方式进行,包括推理步骤 f h f_h fh 和执行步骤 e h e_h eh(例如编写 python 代码),我们顺便提一下,最近在 Zhong et al.|( ;|Rafailov et al.|(2024);/Xie et al.|(2024a) 中也研究了偏好学习的这种 MDP 公式化,但侧重于单轮聊天任务,并且没有明确考虑外部消息。'当没有歧义时,也采用缩写 s h + 1 ∼ P h ( ⋅ ∣ s h , a h ) s_{h+1} \sim P_h(\cdot|s_h, a_h) sh+1∼Ph(⋅∣sh,ah)。
为了将问题与从相对反馈中学习的 RLHF 联系起来,我们遵循 |Ouyang et al. (2022 2022) 假设我们可以查询 Bradley-Terry 模型以获得偏好信号。定义 1 (Bradley-Terry 模型)。我们表示 y = τ ∖ x y = \tau \setminus x y=τ∖x,其中提示从轨迹中排除。我们假设存在一个轨迹的效用函数 u ∗ u^* u∗,使得给定 ( x , y ′ , y ′ ′ ) (x, y', y'') (x,y′,y′′),一个响应 y ′ y' y′ 比另一个响应 y ′ ′ y'' y′′ 更受偏好,记为 y ′ > y ′ ′ y' > y'' y′>y′′,概率为
Prob ( y ′ > y ′ ′ ∣ x , y ′ , y ′ ′ ) = σ ( u ∗ ( x , y ′ ) − u ( x , y ′ ′ ) ) ( 1 ) \text{Prob}(y' > y'' | x, y', y'') = \sigma(u^*(x, y') - u(x, y'')) \quad (1) Prob(y′>y′′∣x,y′,y′′)=σ(u∗(x,y′)−u(x,y′′))(1)
其中 σ \sigma σ 是 sigmoid 函数 σ ( z ) = 1 / ( 1 + exp ( − z ) ) \sigma(z) = 1 / (1 + \exp(-z)) σ(z)=1/(1+exp(−z))。此外,给定 ( x , y ′ , y ′ ′ ) (x, y', y'') (x,y′,y′′),我们将抽取的偏好信号记为 z z z,其中 z = 1 z=1 z=1 表示 y ′ > y ′ ′ y' > y'' y′>y′′,而 z = 0 z=0 z=0 表示 y ′ ′ > y ′ y'' > y' y′′>y′。
这里我们只假设可以访问轨迹级偏好,而不是行动级偏好。然而,我们注意到效用函数本身可以以分步方式定义。效用函数的示例包括检查最终结果的二元奖励、结果监督奖励模型 (Cobbe et al.|{2021b)) 和过程监督奖励模型 (Lightman et al. )。
2.2 使用模型进行规划:最优性条件和实用算法
我们在本节中用一般 MDP 公式开发主要算法。遵循 (2023),我们首先建立模型 M = ( S , A , H , P , d 0 , u ) M = (S, A, H, P, d_0, u) M=(S,A,H,P,d0,u) 及其相关最优策略之间的联系。特别是,我们对以下关于参考策略 π ref \pi_{\text{ref}} πref 的 KL 正则化规划问题感兴趣:
arg max π J ( π ; M , π ref ) = E x ∼ d 0 , τ ∼ π ( ⋅ ∣ x ) , o h ∼ P h ( ⋅ ∣ s h , a h ) [ u ( τ ) − η ∑ h = 1 H D KL ( π h ( ⋅ ∣ s h ) , π ref , h ( ⋅ ∣ s h ) ) ] ( 2 ) \arg \max_\pi J(\pi; M, \pi_{\text{ref}}) = \mathbb{E}_{x \sim d_0, \tau \sim \pi(\cdot|x), o_h \sim P_h(\cdot|s_h, a_h)} [u(\tau) - \eta \sum_{h=1}^H D_{\text{KL}}(\pi_h(\cdot|s_h), \pi_{\text{ref},h}(\cdot|s_h))] \quad (2) argπmaxJ(π;M,πref)=Ex∼d0,τ∼π(⋅∣x),oh∼Ph(⋅∣sh,ah)[u(τ)−ηh=1∑HDKL(πh(⋅∣sh),πref,h(⋅∣sh))](2)
在 H = 1 H=1 H=1 的单轮情况下,关于效用函数 u u u 的最优解是 Gibbs 分布(参见引理)。转向多轮情况,我们首先考虑 H = 2 H=2 H=2 来阐明思想。想法是进行从 h = H = 2 h=H=2 h=H=2 到 h = 1 h=1 h=1 的向后迭代。具体而言,当我们固定 s 2 s_2 s2 仅考虑步骤 2 时,它简化为单轮情况:
π M , 2 ∗ ( ⋅ ∣ s 2 ) = arg max π 2 E a 2 ∼ π 2 ( ⋅ ∣ s 2 ) [ u ( s 2 , a 2 ) − η D KL ( π 2 ( ⋅ ∣ s 2 ) , π ref , 2 ( ⋅ ∣ s 2 ) ) ] ∝ π ref , 2 ( ⋅ ∣ s 2 ) exp ( u ( s 2 , a 2 ) η ) \pi^*_{M,2}(\cdot|s_2) = \arg \max_{\pi_2} \mathbb{E}_{a_2 \sim \pi_2(\cdot|s_2)} [u(s_2, a_2) - \eta D_{\text{KL}}(\pi_2(\cdot|s_2), \pi_{\text{ref},2}(\cdot|s_2))] \propto \pi_{\text{ref},2}(\cdot|s_2) \exp\left(\frac{u(s_2, a_2)}{\eta}\right) πM,2∗(⋅∣s2)=argπ2maxEa2∼π2(⋅∣s2)[u(s2,a2)−ηDKL(π2(⋅∣s2),πref,2(⋅∣s2))]∝πref,2(⋅∣s2)exp(ηu(s2,a2))
然后,我们可以定义与 π M , 2 ∗ \pi^*_{M,2} πM,2∗ 相关的值函数为
V M , 2 ∗ ( s 2 ) = E a 2 ∼ π M , 2 ∗ ( ⋅ ∣ s 2 ) [ u ( s 2 , a 2 ) − η D KL ( π M , 2 ∗ ( ⋅ ∣ s 2 ) , π ref , 2 ( ⋅ ∣ s 2 ) ) ] V^*_{M,2}(s_2) = \mathbb{E}_{a_2 \sim \pi^*_{M,2}(\cdot|s_2)} [u(s_2, a_2) - \eta D_{\text{KL}}(\pi^*_{M,2}(\cdot|s_2), \pi_{\text{ref},2}(\cdot|s_2))] VM,2∗(s2)=Ea2∼πM,2∗(⋅∣s2)[u(s2,a2)−ηDKL(πM,2∗(⋅∣s2),πref,2(⋅∣s2))]
Q M , 1 ∗ ( s 1 , a 1 ) = E o 1 ∼ P 1 ( ⋅ ∣ s 1 , a 1 ) [ V M , 2 ∗ ( s 2 ) ] Q^*_{M,1}(s_1, a_1) = \mathbb{E}_{o_1 \sim P_1(\cdot|s_1, a_1)} [V^*_{M,2}(s_2)] QM,1∗(s1,a1)=Eo1∼P1(⋅∣s1,a1)[VM,2∗(s2)]
对于步骤 1,由于我们已经确定了 π M , 2 ∗ \pi^*_{M,2} πM,2∗,通过 Q M , 1 ∗ ( s 1 , a 1 ) Q^*_{M,1}(s_1, a_1) QM,1∗(s1,a1) 的定义,我们有
π M , 1 ∗ ( ⋅ ∣ s 1 ) = arg max π 1 E a 1 ∼ π 1 ( ⋅ ∣ s 1 ) [ Q M , 1 ∗ ( s 1 , a 1 ) − η D KL ( π 1 ( ⋅ ∣ s 1 ) , π ref , 1 ( ⋅ ∣ s 1 ) ) ] ∝ π ref , 1 ( ⋅ ∣ s 1 ) exp ( Q M , 1 ∗ ( s 1 , a 1 ) η ) \pi^*_{M,1}(\cdot|s_1) = \arg \max_{\pi_1} \mathbb{E}_{a_1 \sim \pi_1(\cdot|s_1)} [Q^*_{M,1}(s_1, a_1) - \eta D_{\text{KL}}(\pi_1(\cdot|s_1), \pi_{\text{ref},1}(\cdot|s_1))] \propto \pi_{\text{ref},1}(\cdot|s_1) \exp\left(\frac{Q^*_{M,1}(s_1, a_1)}{\eta}\right) πM,1∗(⋅∣s1)=argπ1maxEa1∼π1(⋅∣s1)[QM,1∗(s1,a1)−ηDKL(π1(⋅∣s1),πref,1(⋅∣s1))]∝πref,1(⋅∣s1)exp(ηQM,1∗(s1,a1))
通过构建, { π M , h ∗ } h = 1 H \{\pi^*_{M,h}\}_{h=1}^H {πM,h∗}h=1H 是最优的,因为它最大化了 KL 正则化目标。对于一般 MDP,我们可以重复此过程 H H H 次,从 V M , H + 1 = 0 V_{M,H+1} = 0 VM,H+1=0 开始,我们递归地定义
Q M , h ( s h , a h ) = { u ( s H , a H ) , if h = H E o h ∼ P h ( ⋅ ∣ s h , a h ) [ V M , h + 1 ( s h + 1 ) ] , if h < H ( 3 ) Q_{M,h}(s_h, a_h) = \begin{cases} u(s_H, a_H), & \text{if } h = H \\ \mathbb{E}_{o_h \sim P_h(\cdot|s_h, a_h)} [V_{M,h+1}(s_{h+1})], & \text{if } h < H \end{cases} \quad (3) QM,h(sh,ah)={u(sH,aH),Eoh∼Ph(⋅∣sh,ah)[VM,h+1(sh+1)],if h=Hif h<H(3)
这里最优策略和 V 值由下式给出
π M , h ∗ ( a h ∣ s h ) = π ref , h ( a h ∣ s h ) exp ( Q M , h ( s h , a h ) / η ) Z h ( s h ) ( Gibbs distribution of Q M , h ) \pi^*_{M,h}(a_h|s_h) = \pi_{\text{ref},h}(a_h|s_h) \frac{\exp(Q_{M,h}(s_h, a_h)/\eta)}{Z_h(s_h)} \quad (\text{Gibbs distribution of } Q_{M,h}) πM,h∗(ah∣sh)=πref,h(ah∣sh)Zh(sh)exp(QM,h(sh,ah)/η)(Gibbs distribution of QM,h)
V M , h ( s h ) = E a h ∼ π M , h ∗ ( ⋅ ∣ s h ) [ Q M , h ( s h , a h ) − η D KL ( π M , h ∗ ( ⋅ ∣ s h ) , π ref , h ( ⋅ ∣ s h ) ) ] V_{M,h}(s_h) = \mathbb{E}_{a_h \sim \pi^*_{M,h}(\cdot|s_h)} [Q_{M,h}(s_h, a_h) - \eta D_{\text{KL}}(\pi^*_{M,h}(\cdot|s_h), \pi_{\text{ref},h}(\cdot|s_h))] VM,h(sh)=Eah∼πM,h∗(⋅∣sh)[QM,h(sh,ah)−ηDKL(πM,h∗(⋅∣sh),πref,h(⋅∣sh))]
= η log E a h ∼ π ref , h ( a h ∣ s h ) exp ( Q M , h ( s h , a h ) η ) = η log Z h ( s h ) ( 4 ) = \eta \log \mathbb{E}_{a_h \sim \pi_{\text{ref},h}(a_h|s_h)} \exp\left(\frac{Q_{M,h}(s_h, a_h)}{\eta}\right) = \eta \log Z_h(s_h) \quad (4) =ηlogEah∼πref,h(ah∣sh)exp(ηQM,h(sh,ah))=ηlogZh(sh)(4)
其中 Z h ( s h ) = ∑ a h ∈ A ( s h ) π ref , h ( a h ∣ s h ) exp ( Q M , h ( s h , a h ) / η ) Z_h(s_h) = \sum_{a_h \in \mathcal{A}(s_h)} \pi_{\text{ref},h}(a_h|s_h) \exp(Q_{M,h}(s_h, a_h)/\eta) Zh(sh)=∑ah∈A(sh)πref,h(ah∣sh)exp(QM,h(sh,ah)/η) 是归一化常数。V 值定义中的第二个等式来自引理 3。然后,根据定义, { π M , h ∗ } h = 1 H \{\pi^*_{M,h}\}_{h=1}^H {πM,h∗}h=1H 是最优的。本质上,我们解决了 H H H 个关于 Q 值的 Gibbs 分布。我们将结果总结到以下命题中。
命题 1. 我们可以递归地定义具有 horizon H 和外部观察 o h o_h oh 的 KL 正则化 MDP 的以下最优值函数和最优策略。对于 Q 值,我们有
Q M , h ( s h , a h ) = { u ( s H , a H ) , if h = H E o h ∼ P h ( ⋅ ∣ s h , a h ) [ V M , h + 1 ( s h + 1 ) ] , if h < H ( 5 ) Q_{M,h}(s_h, a_h) = \begin{cases} u(s_H, a_H), & \text{if } h = H \\ \mathbb{E}_{o_h \sim P_h(\cdot|s_h, a_h)} [V_{M,h+1}(s_{h+1})], & \text{if } h < H \end{cases} \quad (5) QM,h(sh,ah)={u(sH,aH),Eoh∼Ph(⋅∣sh,ah)[VM,h+1(sh+1)],if h=Hif h<H(5)
此外,对于所有 h ∈ [ H ] h \in [H] h∈[H],我们有:
我们有一些有趣的观察结果,它们可能具有独立的意义。
- 由于附加的 KL 约束,最优值函数通过对初始参考策略的期望来表征。
- 对于固定的步骤 h h h 和状态-行动对 ( s h , a h ) (s_h, a_h) (sh,ah),我们可以将未来视为一个 bandit(只有一步),然后,我们有 Q M , h ( s h , a h ) = E z u ( s h , a h , z ) Q_{M,h}(s_h, a_h) = \mathbb{E}_{z} u(s_h, a_h, z) QM,h(sh,ah)=Ezu(sh,ah,z),其中 z z z 是从 ( s h , a h ) (s_h, a_h) (sh,ah) 开始的完成。可以使用蒙特卡罗估计来估计这个值,通过多次 rollout。我们注意到,这个过程的非正则化版本通常被称为过程监督奖励(PRM)(Wang et al.|2023a)。换句话说,在 a) 中构建的 PRM 本质上是一个 Q-学习过程。
我们注意到,结果本质上来自熵正则化 MDPs (Williams & Peng| art} |2010)。
E ( x , τ + , τ − ) ∼ D [ log σ ( η ∑ h = 1 H ( log π θ ( a h + ∣ s h + ) − log π θ ( a h − ∣ s h − ) ) ) + η ∑ h = 1 H ( log π ref , h ( a h + ∣ s h + ) − log π ref , h ( a h − ∣ s h − ) ) ] ( 9 ) \mathbb{E}_{(x, \tau^+, \tau^-) \sim \mathcal{D}} [\log \sigma(\eta \sum_{h=1}^H (\log \pi_\theta(a_h^+|s_h^+) - \log \pi_\theta(a_h^-|s_h^-))) + \eta \sum_{h=1}^H (\log \pi_{\text{ref},h}(a_h^+|s_h^+) - \log \pi_{\text{ref},h}(a_h^-|s_h^-))] \quad (9) E(x,τ+,τ−)∼D[logσ(ηh=1∑H(logπθ(ah+∣sh+)−logπθ(ah−∣sh−)))+ηh=1∑H(logπref,h(ah+∣sh+)−logπref,h(ah−∣sh−))](9)
这里,项 (A) 类似于单轮情况,而项 (B) 对于具有相同提示 s 1 s_1 s1 的两个样本的奖励差异将被抵消。然而,在实践中,项 © 通常不可直接计算,因为项 (C’) 与外部环境的随机性有关。对于这项工作的重点,即工具集成数学推理,幸运的是代码执行结果由历史(LLMs 编写的代码)确定。这导致项 © = 0。因此,我们可以将方程 ( 9 ) (9) (9) 插入到基于包含 ( x , τ + , τ − ) (x, \tau^+, \tau^-) (x,τ+,τ−) 的数据集 D \mathcal{D} D 的效用函数的最大似然估计中,得到以下多轮 DPO (M-DPO) 损失:
L M-DPO ( θ ) = − E ( x , τ + , τ − ) ∼ D [ log σ ( η ∑ h = 1 H ( log π θ ( a h + ∣ s h + ) π ref , h ( a h + ∣ s h + ) − log π θ ( a h − ∣ s h − ) π ref , h ( a h − ∣ s h − ) ) ) ] ( 10 ) \mathcal{L}_{\text{M-DPO}}(\theta) = - \mathbb{E}_{(x, \tau^+, \tau^-) \sim \mathcal{D}} \left[ \log \sigma\left(\eta \sum_{h=1}^H \left(\log \frac{\pi_\theta(a_h^+|s_h^+)}{\pi_{\text{ref},h}(a_h^+|s_h^+)} - \log \frac{\pi_\theta(a_h^-|s_h^-)}{\pi_{\text{ref},h}(a_h^-|s_h^-)}\right)\right) \right] \quad (10) LM-DPO(θ)=−E(x,τ+,τ−)∼D[logσ(ηh=1∑H(logπref,h(ah+∣sh+)πθ(ah+∣sh+)−logπref,h(ah−∣sh−)πθ(ah−∣sh−)))](10)
类似地,我们可以在确定性转换下实现 M-KTO。我们请感兴趣的读者参考附录 A 以获取损失函数详情。
2.3 在线迭代训练
现在我们将规划算法 M-DPO 与在线迭代学习框架相结合,这受到了其在单轮情况中巨大成功的启发 (Xiong et al.||2024| 0246)。
学习目标。为了更全面地理解其统计行为,我们将考虑两种不同的学习目标。第一个目标是 KL 正则化目标:
max π E x ∼ d 0 , a h ∼ π ( ⋅ ∣ s h ) , o h ∼ P ∗ ( ⋅ ∣ s h , a h ) [ u ∗ ( x , y ) − η ∑ h = 1 H D KL ( π h ( ⋅ ∣ s h ) , π 0 ( ⋅ ∣ s h ) ) ] ( 11 ) \max_\pi \mathbb{E}_{x \sim d_0, a_h \sim \pi(\cdot|s_h), o_h \sim P^*(\cdot|s_h, a_h)} \left[u^*(x, y) - \eta \sum_{h=1}^H D_{\text{KL}}(\pi_h(\cdot|s_h), \pi_0(\cdot|s_h))\right] \quad (11) πmaxEx∼d0,ah∼π(⋅∣sh),oh∼P∗(⋅∣sh,ah)[u∗(x,y)−ηh=1∑HDKL(πh(⋅∣sh),π0(⋅∣sh))](11)
即 max π J ( π ; M ∗ , π 0 ) \max_\pi J(\pi; M^*, \pi_0) maxπJ(π;M∗,π0),其中 M ∗ = ( S , A , H , P ∗ , d 0 , u ∗ ) M^* = (S, A, H, P^*, d_0, u^*) M∗=(S,A,H,P∗,d0,u∗) 是真实环境, π 0 \pi_0 π0 是初始 SFT 策略。这个目标在 RLHF 中被广泛采用,需要我们只在以 SFT 策略 π 0 \pi_0 π0 为中心的固定 KL 球中搜索最优策略。相比之下,第二个是非正则化目标,即直接优化奖励:
max π E x ∼ d 0 , a h ∼ π ( ⋅ ∣ s h ) , o h ∼ P ∗ ( ⋅ ∣ s h , a h ) [ u ∗ ( x , y ) ] ( 12 ) \max_\pi \mathbb{E}_{x \sim d_0, a_h \sim \pi(\cdot|s_h), o_h \sim P^*(\cdot|s_h, a_h)} [u^*(x, y)] \quad (12) πmaxEx∼d0,ah∼π(⋅∣sh),oh∼P∗(⋅∣sh,ah)[u∗(x,y)](12)
这个目标是经典 RL 研究中的标准目标 (Sutton & Barto} [2018)。这个目标的一个动机是,在推理任务中,奖励函数比聊天任务更具可解释性(例如检查最终结果)。
算法框架。我们在算法 [I] 中介绍了一种通用的在线迭代算法框架。该框架被称为在线迭代多轮 Gibbs 采样从人类反馈(M-GSHF),因为最优策略是一种层状 Gibbs 分布,它泛化了 (2024) 中的结果。我们现在讨论该框架的一些特征如下。参考模型选择用于控制正则化水平。我们通过将参考模型选择作为超参数来统一方程 [IT] 和方程 [I2] 中的两个不同学习目标。首先,如果我们将参考模型固定为初始策略,即 π ref = π 0 \pi_{\text{ref}} = \pi_0 πref=π0, ∀ t ∈ [ T ] \forall t \in [T] ∀t∈[T],我们总是在以 π 0 \pi_0 π0 为中心的 KL 球内搜索最优策略,从而优化 KL 正则化目标。相比之下,受镜像下降 (Nemirovskij & Yudin 图 1:两个学习目标之间差异的说明。左:KL 正则化目标,因为我们不更新参考模型。(1983)) 的启发,如果我们每次迭代都将参考策略更新为上一次迭代学习的策略,即 π ref = π t − 1 \pi_{\text{ref}} = \pi_{t-1} πref=πt−1, ∀ t ∈ [ T ] \forall t \in [T] ∀t∈[T],累积更新可以使模型偏离原始的 π 0 \pi_0 π0(同时对每次迭代更新的大小进行约束),因此我们优化方程 [12] 中的非正则化目标。参见图 [I] 以获取说明。右:非正则化目标。非对称策略选择用于探索-利用权衡。我们以非对称方式更新我们的行为策略。第一个智能体旨在提取我们迄今收集的历史信息,并运行第 2.2 节中介绍的 M-DPO 或 M-DKO。然而,RL 研究 (2018) (2002) 广泛认为,简单地通过遵循经验上最好的模型来利用历史数据不足以获得良好的最终策略,同时还需要探索环境,以便可以收集新信息以促进后续学习,即探索-利用权衡。因此,第二个智能体将战略性地整合关于未来相对于 π t \pi_t πt 的不确定性,以选择 π t ′ \pi'_t πt′,这被称为探索策略。
算法 [I] 的全面理论分析推迟到附录 [D],由于篇幅限制,重点关注 KL 正则化目标。在这里,我们强调以下非正式结果(参见定理 [2] 了解完整版本),强调算法 [I] 的效率由次线性后悔保证。优化奖励的另一个目标已在 Wang et al.|(2023b) 中进行了理论研究,而分析镜像下降式算法的技术已在 (2020) 中开发。
定理 1 (非正式). 在可实现性假设下,对于 KL 正则化目标,算法 [Z] 的理论版本对于广泛的奖励和转换模型类别的后悔(方程 {5} 中定义)是次线性于 horizon T 的。
该定理的主要 takeaway 信息是,如果我们选择合适的探索策略,在线迭代学习是可证明有效的。我们还注意到,如果没有明确的机制来鼓励探索,LLM 本身的随机性不足以学习最优策略 2022),如果我们不做附加假设。转向实际算法设计,探索通常被解释为通过采用基础 DPO 策略 π t \pi_t πt 的推理时方法来增加收集到的数据的多样性。例如,可以像 Llama 项目 (Touvron et al.|/2023) 中那样调整采样温度,或者使用 best-of-n 采样 (Xu et al.|{2023}/Hoang Tran| 2024} |Dong et al.| 2024),其中这些方法比 vanilla on-policy 采样具有相当大的优势。在这项工作中,我们主要通过各种中间检查点来丰富生成的数据,就像 Claude 项目 (Bai et al. ) 所做的那样。我们将这种方法称为混合采样。采用奖励引导的蒙特卡罗树搜索 (MCTS) 也是很自然的 2024b),我们将其留待未来工作。
算法 1 在线迭代 M-GSHF
1:输入:KL 系数 η > 0 \eta > 0 η>0,horizon T > 0 T > 0 T>0,初始策略 π 0 \pi_0 π0,批大小 m > 0 m > 0 m>0。
2:初始化 D ← ∅ \mathcal{D} \leftarrow \emptyset D←∅, π t 1 = π t 2 = π ref , t ← π 0 \pi_t^1 = \pi_t^2 = \pi_{\text{ref},t} \leftarrow \pi_0 πt1=πt2=πref,t←π0。
3:对于 t = 1 , 2 , ⋯ , T t = 1, 2, \cdots, T t=1,2,⋯,T 执行
4:通过 x ∼ d 0 , τ 1 ∼ π t 1 , τ 2 ∼ π t 2 x \sim d_0, \tau^1 \sim \pi_t^1, \tau^2 \sim \pi_t^2 x∼d0,τ1∼πt1,τ2∼πt2 采样 m m m 对 ( x , τ 1 , τ 2 , z ) (x, \tau^1, \tau^2, z) (x,τ1,τ2,z) 作为 D t \mathcal{D}_t Dt,接收根据定义 [T] 的 Bradley-Terry 模型得到的 m m m 个偏好信号 z z z,并更新偏好数据集 D ← D ∪ D t \mathcal{D} \leftarrow \mathcal{D} \cup \mathcal{D}_t D←D∪Dt。
5:> 从历史数据中提取经验最优策略
6:实践:在 D \mathcal{D} D 上执行规划算法以获得 π t 1 \pi_t^1 πt1(例如,使用方程 [10] 中的 M-DPO 损失或方程 }13 中的 M-KTO 损失)。
7:理论:在 D \mathcal{D} D 上执行 MLE 以获得模型估计 M t = ( u ^ t , P ^ t ) M_t = (\hat{u}_t, \hat{P}_t) Mt=(u^t,P^t),如方程 和方程 所述,使用 M t , η , π ref , t M_t, \eta, \pi_{\text{ref},t} Mt,η,πref,t 获得 π t 1 \pi_t^1 πt1。
8:> 选择探索策略以促进学习
9:实践:给定 π t 1 \pi_t^1 πt1,使用启发式方法(例如混合采样、推理参数调整和 west-of-n 采样)选择 π t 2 \pi_t^2 πt2 作为探索策略。
10:理论:给定 π t 1 \pi_t^1 πt1,根据方程 [18] 选择 π t 2 \pi_t^2 πt2 作为探索策略。
11:> 选择参考模型以控制正则化水平
12:考虑非正则化目标时更新 π t + 1 , ref ← π t 1 \pi_{t+1, \text{ref}} \leftarrow \pi_t^1 πt+1,ref←πt1;考虑 KL 正则化目标时保持 π t + 1 , ref ← π 0 \pi_{t+1, \text{ref}} \leftarrow \pi_0 πt+1,ref←π0。
13:结束 for 循环。
14:输出:通过验证集得到的 π t 1 \pi_t^1 πt1 中最好的模型。
3 EXPERIMENTS
3.1 EXPERIMENT SETUP
任务、数据集和模型。我们使用 MATH 和 GSM8K (Cobbe et al.|/2021a) 的测试集来衡量模型解决数学问题的能力。为了构建训练提示集,我们使用 MetaMathQA (2023) 和 MMIQC (Liu & Yao||2024) 的提示,这是一个来自 MATH 的 7.5K 训练问题和 GSM8K 的 7.47K 训练问题的增强提示集。我们在图 [4] 中提供了一个数据样本示例。我们使用各种基础模型进行训练,包括 Gemma-1.1-it-7B 2024)、CodeGemma-1.1-it-7B (Team||2024)、Mistral-7B-v0.4?| (Jiang et al.||2023) 和 Gemma2-it-9B。我们首先使用 Open-MathInstruct 数据集的子集对模型进行微调。SFT 过程的详细信息在附录 |[B] 中提供。我们使用预训练版本是因为其来自 huggingface 的 instruct 模型的聊天模板与其自己的代码库不一致。
表 1:不同方法在 GSM8K 和 MATH 测试集上的主要结果。}:该模型作为其他方法的起始检查点。CoT 方法的结果取自技术报告 2024} [2023b)。对于迭代 M-DPO/M-KTO,除非另有说明,否则默认更新参考模型。相对于 SFT 起始检查点的增益标记为 +。基本模型 使用工具的方法
迭代 M-DPO 和 M-KTO 的实现。我们总共进行了 3 个 epoch 的迭代训练。对于每次迭代,我们有一个 20K 问题的提示集,使用当前 DPO 模型生成每个提示 20 个响应,使用上次迭代的模型生成每个提示 10 个响应。
3.2 MAIN RESULTS
我们检查这些响应的最终答案以确定其正确性。然后,对于每个提示,我们随机抽取两个具有正确和不正确最终答案的响应,并将它们添加到训练样本中。然后,我们使用 M-DPO/M-KTO 损失对收集到的样本训练模型。我们还包括了对参考模型选择的消融研究。为了实现 M-DPO,我们简单地将所有用户轮次 token 的标签设置为 -100,并在后续的损失计算中屏蔽对数概率。我们最多训练模型 1 个 epoch,并在迭代训练的第一次迭代中调整学习率在 {2e-7, 4e-7, 7e-7, 1e-6} 中。最终,Gemma-1.1 模型使用了 4e-7 的学习率,Gemma-2 模型和 Mistral 模型使用了 2e-7 的学习率。全局批量大小为 32,预热步数为 40。我们每 50 个训练步数通过分割的提示集评估模型。M-KTO 的超参数与 M-DPO 的基本相同。我们还遵循原始 KTO 论文 (Ethayarajh et al.||2024) 设置了 λ + = λ − = 1 \lambda_+ = \lambda_- = 1 λ+=λ−=1。
基线。现有文献主要关注合成数据生成和 SFT 来教授模型使用外部工具。我们使用 [1] 的结果作为基线,因为我们使用相同的 SFT 数据集,因此结果大致可比。对于 CoT 基线,我们使用来自 [33] 的 Wizardmath 模型。我们还包括奖励排序微调(RAFT)作为基线 (Dong et al.|/2023),它也被称为拒绝采样微调 (Touvron et al.|/2023)。另一个基线是单轮在线迭代 DPO 和 KTO 2023}|Ethayarajh et al.|!2024),它们忽略了问题结构(即外部消息),并将轨迹作为一个整体对待。在实现中,这意味着我们不屏蔽外部消息的 token。
我们在零样本设置下评估模型,并将主要结果报告在表 [I] 中。从表 [I] 的前两节可以看出,工具集成的 LLMs 显著优于仅使用 SFT 的 CoT 对应模型,这表明利用外部工具的好处。在随后的讨论中,我们将重点关注工具集成 LLMs 范围内的比较。
迭代 M-DPO 和 M-KTO 显著改善了 SFT 模型。在所有四个基本模型上,使用 M-DPO 或 M-KTO 的迭代训练始终在 GSM8K 和 MATH 上比初始 SFT 检查点带来显著改进。特别是,使用 M-DPO,对齐后的 Gemma-1.1-it-7B 模型在 GSM8K 和 MATH 上分别达到 83.9% 和 51.2% 的准确率,与开源的 Open-MathInstruct 微调 CodeLLaMA-2-70B 相当(在 GSM8K 上略差,但在 MATH 上略好)。此外,对齐后的 Gemma-2-it-9B 模型在 GSM8K 和 MATH 上分别达到 86.3% 和 54.5% 的准确率,超过了所有使用 Open-MathInstruct 训练的 7B 到 70B 范围内的开源模型。总的来说,我们的框架在 SFT 后可以稳健地进一步提升工具集成模型的性能。
迭代 M-DPO 和 M-KTO 超越了现有的 RLHF 基线。我们还观察到,迭代 M-DPO 和 M-KTO 超越了其他现有的 RLHF 基线。首先,它们在所有四个基础模型上始终且显著地优于 RAFT。这是因为 RAFT 只模仿正确的轨迹,而基于 DPO 和基于 KTO 的算法进一步利用了来自错误轨迹的负信号。我们注意到,我们流水线中的 SFT 阶段也可以被视为 RAFT 的应用。因此,我们的结果应解释为,在第一阶段 SFT 之后,带有负信号的算法具有更高的样本效率。此外,虽然在线迭代单轮 DPO (KTO) 也提供了更好的性能,但它通常不如多轮版本。这表明学习预测由代码解释器返回的离策略外部消息通常会对推理能力的提高产生负面影响。我们还在图 5 中展示了一个我们遇到的代表性示例,其中 LLMs 生成了构建不良的代码,导致出现异常且冗长的外部消息。强迫 LLMs 学习预测这些消息会显著损害模型的推理能力。
迭代训练和参考更新带来更好的性能。以使用 M-DPO 的 Gemma-1.1-it-7B 为例,我们观察到在线迭代训练带来了更好的结果。GSMS8K 测试准确率从 77.5% (SFT) 提高到 81.5% (iter 1),再到 82.5% (iter2),最后到 83.9% (iter3),而 MATH 的测试准确率从 46.1% (SFT) 提高到 49.1% (iter 1),再到 49.7% (iter2),最后到 51.2% (iter3)。这与我们的理论洞察一致,即迭代训练有助于模型逐步探索和学习最优策略。此外,如果参考模型固定在 SFT 策略上,最终性能会比每次迭代更新参考模型显著更差。这可能是因为在这种情况下,算法优化的是非正则化奖励,而数学推理任务中的奖励比一般聊天任务中的更准确,从而带来了更好的领域内性能。关于 KL 正则化影响的详细消融研究推迟到下一节。
图 2:pass @n 准确率与候选数量 n 的关系。我们遵循先前的工作 (2024); |Toshniwal et al.|(2024) 使用温度 0.7 评估模型。我们注意到,偏好学习仅在 n 相对较小时提高了 pass@n 指标。
偏好学习仅在 n 相对较小时提高了 pass@n。我们在图 [2] 中绘制了 pass@n 准确率与候选轨迹数量 n 的关系。如果至少有一个 n 个采样轨迹是正确的,则认为问题已解决。我们发现,偏好学习仅在 n 较小时提高了 pass@n 准确率。对于 n > 16,所有模型在 GSM8K 和 MATH 上的表现相似,这表明迭代 M-DPO 不会引入新知识,而是提高了 top-n 响应的质量。这一观察结果也与 CoT 推理的结果一致 (Shao et al.|/2024)。
3.3 消融研究和讨论
适度的 KL 正则化平衡了每次迭代的改进和探索。迭代 DPO 的有效性高度依赖于参考模型和 KL 系数。在我们的消融研究中,我们首先考虑了两种不同的参考模型选择:(1)使用固定的参考模型 π 1 \pi_1 π1;方法 | GSM8K MATH SFT 77.5 46.1 更新参考 +) = 0.01 | 81.7 50.1 更新参考 + η \eta η = 0.1 | 83.9 51.2 更新参考 + η \eta η = 0.5 | 82.8 49.7 固定参考 + = 0.1 | 79.9 48.0 (2)每次迭代将参考模型更新为上次迭代的模型,这可以视为生成表 2:KL 正则化对迭代 M-DPO 影响的消融研究。多样性和奖励优化之间的权衡。如表 B.3 所示,更新参考模型的模型优于使用固定参考模型的模型。我们假设在推理任务中,正确的推理路径高度集中,使得多样性不太重要,因此优化非正则化奖励可以获得更好的模型性能。先前的工作 (Tunstall et al.|) 关于离线 DPO 表明,较低的 KL 系数 (0.01) 通过允许模型更多地偏离 SFT 模型 π 0 \pi_0 π0 来提高性能。在我们的消融研究中,我们搜索了 KL 系数 η ∈ { 0.01 , 0.1 , 0.5 } \eta \in \{0.01, 0.1, 0.5\} η∈{0.01,0.1,0.5}。根据表 B.3,我们发现最强的模型是通过适度的 KL 系数 0.1 获得的,它优于 0.01 和 0.5。为了解释这一点,我们在迭代训练期间绘制了 GSM8K 测试准确率(图 [3p])。在第一次迭代中,较低的 KL 值显示出更大的改进,与 [Tunstall et al.| (2023)] 的结果一致。然而,使用非常低的 KL 系数训练的模型会迅速失去多样性,降低了它们生成多样化轨迹以供后续训练的能力,导致后续迭代的收益递减。相反,较高的 KL 系数 0.5 施加了过多的正则化,限制了每次迭代的改进。总之,对于在线迭代训练,平衡每次迭代的改进和探索效率是优化整体性能的关键,这种直觉也适用于采样策略和其他实验技术。
图 3:左:不同 KL 正则化水平下 GSM8K 数据集的测试准确率。右:不同采样策略下 MATH 数据集的测试准确率。
采样策略的影响:数据多样性和覆盖率至关重要。在 Gemma-1.1-it-7B 的迭代训练中,我们看到正确轨迹的比例从第一次迭代的 47% 增加到最后一次迭代的 76%。此外,随着参考模型在每一步更新,轨迹多样性下降,这对于 DPO/KTO 训练至关重要,因为它具有对比性质。我们遵循 ; [Dong et al.| (2024) 探索了两种数据收集策略:(1)on-policy 采样(从当前模型采样的轨迹)和(2)混合采样(20 条轨迹来自当前模型,10 条来自上一次迭代的模型)。如表 [6] 所示,混合采样显著优于 on-policy 采样,特别是在第三次迭代中,on-policy 采样未能提高 MATH 测试准确率。这突出了迭代训练中多样性的重要性,并与先前的研究结果一致,即高级探索策略有助于防止多样性崩溃并改进偏好学习 (Bai et al.| {2022} {Touvron et al.| /2023| [Xiong et al.| 2024] 2024| 2024)。探索更高级的探索策略,例如 MCTS,在未来的研究中也会很有趣。为了确保正确和不正确的推理路径都存在,我们为每个方法收集了 N 条轨迹 | GSM8K MATH 提示。较大的 N 通常会提高提示 SFT 77.5 46.1 的覆盖率,因为对于困难问题需要更多样本。N=30 + Mixture 83.9 51.2 例如,在迭代 1 中,当 N=30 时,92.5% 的提示被覆盖,N=12 + Mixture 83.5 51.2 而 N=12 为 83.0%,N=6 为 60%。参见 N=6 + Mixture 82.0 49.2 图 [2] 以了解 pass@1 与 N 之间的关系。然而,增加 N=30 + On-policy 83.1 49.5 N 也会增加计算成本。在我们的消融研究中(表 [3.3]),我们发现将 N 从 6 增加到 12 会导致性能显著提升,反映了复杂问题更好的覆盖率。然而,将 N 从 12 增加到 30 仅产生微小改进,表明在 vanilla 拒绝采样中,较大的 N 的好处迅速减弱。我们预计困难感知采样可以带来更好的性能,同时保持适度的推理成本。
4 结论、限制和未来研究方向
在本文中,我们证明了偏好学习作为监督微调的替代方法,进一步增强了工具集成推理 LLMs 在 SFT 后的性能。我们介绍了一种在线迭代多轮直接偏好优化算法,并通过对多个基础模型的广泛实验进行了验证。结果显示,与 SFT 策略相比,pass@1 有显著提高,特别是在 GSM8K 和 MATH 等基准测试上。消融研究强调了平衡每次迭代的改进和探索的重要性,这通过适度的 KL 正则化和战略性探索选择得以实现。
仍然有几个有待探索的改进途径。我们目前的方法仅使用最终结果检查作为偏好信号,限制了具有正确或不正确答案的轨迹之间的比较。在数据排序阶段可以使用逐步奖励信号)。同时,细粒度的奖励信号可以支持使用高级探索策略,如 west-of-n 采样 (Pace et al.|[2024)) 或 MCTS (Xie et al.|/2024b)) 在我们的启发式探索实现中。最后,虽然直接偏好学习算法在带有代码解释器的数学推理任务中显示出有希望的收益,但它不能直接应用于更复杂和随机的外部环境或对抗动态对手的通用智能体学习。特别是,它需要构建一个价值网络,以便在优化目标中包含自适应裕度,并考虑外部环境的随机性。我们将对这种更复杂的算法的研究留待未来工作。
我们相信使结果可复现非常重要。遵循 ICLR 的作者指南,我们在此提供可复现性声明,以帮助感兴趣的读者复现我们的结果。大多数实现细节,包括超参数,都在第 -T 节和附录 |B] 中提供。此外,我们已开源我们的训练代码以及分步指南,以 Gemma-1.1-it-7B 为例。我们还提供了处理后的 SFT 数据集、提示集和 M-DPO/M-KTO 第一次迭代的训练数据,以便于下载(详情参见补充材料)。本文的 RLHF 实验在 8xA100 80G GPU 上运行,另外还使用了一台带有 8xA100 40G GPU 的机器来加速数据