DPO

$$ \nonumber \newcommand{aster}{*} \newcommand{exist}{\exists} \newcommand{B}{\mathbb B} \newcommand{C}{\mathbb C} \newcommand{I}{\mathbb I} \newcommand{N}{\mathbb N} \newcommand{Q}{\mathbb Q} \newcommand{R}{\mathbb R} \newcommand{Z}{\mathbb Z} \newcommand{eR}{\overline {\mathbb R}} \newcommand{cD}{ {\mathbb D}} \newcommand{dD}{ {\part \mathbb D}} \newcommand{dH}{ {\part \mathbb H}} \newcommand{eC}{\overline {\mathbb C}} \newcommand{A}{\mathcal A} \newcommand{D}{\mathcal D} \newcommand{E}{\mathcal E} \newcommand{F}{\mathcal F} \newcommand{G}{\mathcal G} \newcommand{H}{\mathcal H} \newcommand{J}{\mathcal J} \newcommand{L}{\mathcal L} \newcommand{U}{\mathcal U} \newcommand{M}{\mathcal M} \newcommand{O}{\mathcal O} \newcommand{P}{\mathcal P} \newcommand{S}{\mathcal S} \newcommand{T}{\mathcal T} \newcommand{V}{\mathcal V} \newcommand{W}{\mathcal W} \newcommand{X}{\mathcal X} \newcommand{Y}{\mathcal Y} \newcommand{bE}{\symbf E} \newcommand{bF}{\symbf F} \newcommand{bD}{\symbf D} \newcommand{bI}{\symbf I} \newcommand{bX}{\symbf X} \newcommand{bY}{\symbf Y} \newcommand{nz}{\mathcal Z} \newcommand{bT}{\mathbb T} \newcommand{bB}{\mathbb B} \newcommand{bS}{\mathbb S} \newcommand{bA}{\mathbb A} \newcommand{bL}{\mathbb L} \newcommand{bP}{\symbf P} \newcommand{bM}{\symbf M} \newcommand{bH}{\mathbb H} \newcommand{dd}{\mathrm d} \newcommand{Mu}{\mathup M} \newcommand{Tau}{\mathup T} \newcommand{ae}{\operatorname{a.e.}} \newcommand{aut}{\operatorname{aut}} \newcommand{adj}{\operatorname{adj}} \newcommand{char}{\operatorname{char}} \newcommand{cov}{\operatorname{Cov}} \newcommand{cl}{\operatorname{cl}} \newcommand{cont}{\operatorname{cont}} \newcommand{e}{\mathbb E} \newcommand{pp}{\operatorname{primitive}} \newcommand{dist}{\operatorname{dist}} \newcommand{diam}{\operatorname{diam}} \newcommand{fp}{\operatorname{Fp}} \newcommand{from}{\leftarrow} \newcommand{Gal}{\operatorname{Gal}} \newcommand{GCD}{\operatorname{GCD}} \newcommand{LCM}{\operatorname{LCM}} \newcommand{fg}{\mathrm{fg}} \newcommand{gf}{\mathrm{gf}} \newcommand{im}{\operatorname{Im}} \newcommand{image}{\operatorname{image}} \newcommand{inj}{\hookrightarrow} \newcommand{irr}{\operatorname{irr}} \newcommand{lcm}{\operatorname{lcm}} \newcommand{ltrieq}{\mathrel{\unlhd}} \newcommand{ltri}{\mathrel{\lhd}} \newcommand{loc}{ {\operatorname{loc}}} \newcommand{null}{\operatorname{null}} \newcommand{part}{\partial} \newcommand{pf}{\operatorname{Pf}} \newcommand{pv}{\operatorname{Pv}} \newcommand{rank}{\operatorname{rank}} \newcommand{range}{\operatorname{range}} \newcommand{re}{\operatorname{Re}} \newcommand{span}{\operatorname{span}} \newcommand{su}{\operatorname{supp}} \newcommand{sgn}{\operatorname{sgn}} \newcommand{syn}{\operatorname{syn}} \newcommand{var}{\operatorname{Var}} \newcommand{res}{\operatorname{Res}} \newcommand{data}{\operatorname{data}} \newcommand{erfc}{\operatorname{erfc}} \newcommand{erfcx}{\operatorname{erfcx}} \newcommand{tr}{\operatorname{tr}} \newcommand{col}{\operatorname{Col}} \newcommand{row}{\operatorname{Row}} \newcommand{sol}{\operatorname{Sol}} \newcommand{lub}{\operatorname{lub}} \newcommand{glb}{\operatorname{glb}} \newcommand{ltrieq}{\mathrel{\unlhd}} \newcommand{ltri}{\mathrel{\lhd}} \newcommand{lr}{\leftrightarrow} \newcommand{phat}{^\widehat{\,\,\,}} \newcommand{what}{\widehat} \newcommand{wbar}{\overline} \newcommand{wtilde}{\widetilde} \newcommand{iid}{\operatorname{i.i.d.}} \newcommand{Exp}{\operatorname{Exp}} \newcommand{abs}[1]{\left| {#1}\right|} \newcommand{d}[2]{D_{\text{KL}}\left (#1\middle\| #2\right)} 
\newcommand{n}[1]{\|#1\|} \newcommand{norm}[1]{\left\|{#1}\right\|} \newcommand{pd}[2]{\left \langle {#1},{#2} \right \rangle} \newcommand{argmax}[1]{\underset{#1}{\operatorname{argmax}}} \newcommand{argmin}[1]{\underset{#1}{\operatorname{argmin}}} \newcommand{p}[1]{\left({#1}\right)} \newcommand{c}[1]{\left \{ {#1}\right\}} \newcommand{s}[1]{\left [{#1}\right]} \newcommand{a}[1]{\left \langle{#1}\right\rangle} \newcommand{cc}[2]{\left(\begin{array}{c} #1 \\ #2 \end{array}\right)} \newcommand{f}{\mathfrak F} \newcommand{fi}{\mathfrak F^{-1}} \newcommand{Fi}{\mathcal F^{-1}} \newcommand{l}{\mathfrak L} \newcommand{li}{\mathfrak L^{-1}} \newcommand{Li}{\mathcal L^{-1}} \newcommand{const}{\text{const.}} \newcommand{Int}{\operatorname{Int}} \newcommand{Ext}{\operatorname{Ext}} \newcommand{Bd}{\operatorname{Bd}} \newcommand{Cl}{\operatorname{Cl}} \newcommand{Iso}{\operatorname{Iso}} \newcommand{Lim}{\operatorname{Lim}} \newcommand{src}{\operatorname{src}} \newcommand{tgt}{\operatorname{tgt}} \newcommand{input}{\operatorname{input}} \newcommand{output}{\operatorname{output}} \newcommand{weight}{\operatorname{weight}} \newcommand{paths}{\operatorname{paths}} \newcommand{init}{\bullet} \newcommand{fin}{\circledcirc} \newcommand{advance}{\operatorname{advance}} \newcommand{di}[2]{\frac{\part}{\part {#1}^{#2}}} $$

Direct Preference Optimization #

Direct Preference Optimization: Your Language Model is Secretly a Reward Model

https://openreview.net/forum?id=HPuSIXJaa9#

Bradley-Terry model #

Consider agents $(A _ i) _ {i=1}^N$, where each agent has a score $\alpha _ i > 0$.

Denote by $i \succ j$ the event that $A _ i$ beats $A _ j$. The Bradley-Terry model defines the winning probability as: $$ P(i \succ j) = \frac{\alpha _ i}{\alpha _ i + \alpha _ j} $$
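
As a quick illustration, here is a minimal Python sketch of the Bradley-Terry winning probability (the function name and the example scores are made up for this note):

```python
def bt_win_prob(alpha_i: float, alpha_j: float) -> float:
    """Bradley-Terry probability that agent i beats agent j."""
    return alpha_i / (alpha_i + alpha_j)

# Example: scores 3.0 vs 1.0 give a 75% chance that agent i wins.
assert abs(bt_win_prob(3.0, 1.0) - 0.75) < 1e-12
```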

RLHF and DPO #

In RLHF, we implement a reward model $\alpha _ \theta(x, y)$ that represents an estimate of the score, where $x$ is the prompt and $y$ is the LLM's response. Suppose $\alpha _ \theta(x, y) := \exp r _ \theta(x, y)$. Then: $$ P(y _ w \succ y _ l) = \frac{\alpha _ \theta(x, y _ w)}{\alpha _ \theta(x, y _ w) + \alpha _ \theta(x, y _ l)} $$ When maximizing the log likelihood, this term becomes: $$ \log P(y _ w \succ y _ l) = \log \s{\frac{1}{1 + \exp\p{r _ \theta(x, y _ l) - r _ \theta(x, y _ w)}}} = \log \sigma\p{r _ \theta(x, y _ w) - r _ \theta(x, y _ l)} $$ RLHF optimizes the following objective: $$ \begin{aligned} \L & = E _ {X \sim \D, Y \sim \pi} \s{r(X, Y)} - \beta \d{\pi(y | x)}{\pi _ {\text{ref}}(y | x)} = E \s{r(X, Y) - \beta \log \frac{\pi(Y | X)}{\pi _ {\text{ref}}( Y | X)}} \end{aligned} $$ This objective has a closed-form maximizer $\pi _ * $, since (up to a constant) it takes the form of a KL divergence: $$ \begin{aligned} -\L & \propto E\s{\log \frac{\pi(Y | X)}{\pi _ {\text{ref}}( Y | X)} - \frac{1}{\beta}r(X, Y)} = E\s{\log \frac{\pi(Y | X)}{\pi _ {\text{ref}}(Y | X)} - \log \exp\c{\frac{1}{\beta} r(X, Y)}}\\ & = E _ {Y \sim \pi(Y|X)}\s{\log \frac{\pi(Y | X)}{\pi _ {\text{ref}}(Y | X) \exp\p{r(X, Y) / \beta}}} = E\s{\log \frac{\pi(Y | X)/Z(X)}{\pi _ {\text{ref}}(Y | X) \exp\p{r(X, Y) / \beta} /Z(X)}}\\ & = E\s{\log \frac{\pi(Y | X)}{\pi _ * (Y | X)} - \log Z(X)} = \d{\pi(y | x)}{\pi _ * (y | x)} + \const \end{aligned} $$ There is a one-to-one relationship between the optimal policy $\pi _ * $ and the reward model $r(x, y)$: $$ \pi _ * (y | x) = \frac{\pi _ {\text{ref}}(y | x) \exp\p{r(x, y) / \beta}}{Z(x)} \implies r(x, y) = \beta\cdot \log \c{\frac{\pi _ * (y | x)}{\pi _ {\text{ref}}(y | x)} Z(x)} $$

  • Here $Z(x)$ is the normalization constant (partition function); it depends only on $x$, so it cancels when the reward differences enter the Bradley-Terry preference likelihood.
  • In RLHF, we first fit the reward model $r(x, y)$, then optimize the policy $\pi(y | x)$ against it.
  • In DPO, we instead notice that $r(x, y)$ can be parameterized directly by $\pi _ * (y | x)$ (see the sketch below).
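
A minimal numerical check of this correspondence on a toy discrete vocabulary (the setup and numbers below are illustrative, not from the paper):

```python
import numpy as np

beta = 0.1
rng = np.random.default_rng(0)

# Toy setup: one prompt x, a vocabulary of 5 candidate responses y.
pi_ref = rng.dirichlet(np.ones(5))   # reference policy pi_ref(y | x)
r = rng.normal(size=5)               # an arbitrary reward r(x, y)

# Closed-form optimal policy: pi_*(y | x) = pi_ref(y | x) * exp(r(x, y) / beta) / Z(x).
unnorm = pi_ref * np.exp(r / beta)
Z = unnorm.sum()
pi_star = unnorm / Z

# Recover the reward from the optimal policy, up to the per-prompt factor Z(x).
r_recovered = beta * np.log(pi_star / pi_ref * Z)
assert np.allclose(r_recovered, r)
```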

As a result, we can optimize the policy and the implicit reward model simultaneously: $$ \L _ {\text{DPO}}(\theta) = E\s{-\log \sigma \c{\beta \log \frac{\pi _ \theta(Y _ w | X)}{\pi _ \text{ref}(Y _ w | X)} - \beta \log \frac{\pi _ \theta(Y _ l|X)}{\pi _ \text{ref}(Y _ l | X)}}} $$ where $y _ w$ is the preferred and $y _ l$ the dispreferred response in each preference pair. You could probably also tune a discrete model in this fashion, taking $y _ w$ from the real data and $y _ l$ from the model's own generations, and it would perform better.
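
A minimal PyTorch sketch of this loss, assuming we already have per-sequence log-probabilities (summed over response tokens) from the policy and from the frozen reference model; the function and argument names are mine, not the paper's:

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_logp_w, policy_logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss from summed log-probs of the preferred (w) and dispreferred (l) responses.

    Each argument has shape (batch,) and holds log pi(y | x) summed over the
    response tokens; the reference log-probs should be computed under torch.no_grad().
    """
    # Implicit rewards: r_theta(x, y) = beta * log(pi_theta(y | x) / pi_ref(y | x)).
    reward_w = beta * (policy_logp_w - ref_logp_w)
    reward_l = beta * (policy_logp_l - ref_logp_l)
    # -log sigma(reward_w - reward_l), averaged over the batch.
    return -F.logsigmoid(reward_w - reward_l).mean()
```

Note that the $\beta \log Z(x)$ term cancels in the reward difference, which is why the loss never needs the partition function.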

DPO Gradient #

Recall that:

  • $\sigma(x) = 1 / (1 + e^{-x})$ and $\sigma'(x) = \sigma(x)(1 - \sigma(x))$.
  • And $f(x) = \log\sigma(x)$ gives $f'(x) = 1 - \sigma(x) = \sigma(-x)$.

The gradient of the DPO loss takes the following form: $$ \begin{aligned} \nabla _ \theta \L _ \text{DPO} (\theta) & = -E \s{\sigma\p{r _ \theta(X, Y _ l) - r _ \theta(X, Y _ w)} \cdot \beta \p{\nabla _ \theta \log \pi _ \theta(Y _ w | X) - \nabla _ \theta \log \pi _ \theta(Y _ l | X)}} \end{aligned} $$ Here $r _ \theta(x, y) := \beta \log \frac{\pi _ \theta(y | x)}{\pi _ \text{ref}(y | x)}$ is the reward model induced by the policy. The sign and the $\sigma$ factor come from differentiating $-\log \sigma(u)$ with $u = r _ \theta(X, Y _ w) - r _ \theta(X, Y _ l)$, which gives $-\sigma(-u)\, \nabla _ \theta u$.
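
As a sanity check, the analytic gradient can be compared against autograd on a toy tabular policy; the sketch below is illustrative only:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
beta = 0.1

# Tabular "policy": logits over 4 candidate responses for a single prompt.
logits = torch.randn(4, requires_grad=True)
ref_logp = torch.log_softmax(torch.randn(4), dim=0)
w, l = 0, 1  # indices of the preferred and dispreferred responses

# DPO loss via autograd.
logp = torch.log_softmax(logits, dim=0)
reward = beta * (logp - ref_logp)            # induced reward r_theta(x, y)
loss = -F.logsigmoid(reward[w] - reward[l])
loss.backward()

# Analytic gradient: -sigma(r_l - r_w) * beta * (grad log pi_w - grad log pi_l).
with torch.no_grad():
    coeff = torch.sigmoid(reward[l] - reward[w]) * beta
    grad_logp = torch.eye(4) - torch.softmax(logits, dim=0)  # row i: d log pi_i / d logits
    analytic = -coeff * (grad_logp[w] - grad_logp[l])
assert torch.allclose(logits.grad, analytic, atol=1e-6)
```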