$$
\nonumber
\newcommand{aster}{*}
\newcommand{exist}{\exists}
\newcommand{B}{\mathbb B}
\newcommand{C}{\mathbb C}
\newcommand{I}{\mathbb I}
\newcommand{N}{\mathbb N}
\newcommand{Q}{\mathbb Q}
\newcommand{R}{\mathbb R}
\newcommand{Z}{\mathbb Z}
\newcommand{eR}{\overline {\mathbb R}}
\newcommand{cD}{ {\mathbb D}}
\newcommand{dD}{ {\part \mathbb D}}
\newcommand{dH}{ {\part \mathbb H}}
\newcommand{eC}{\overline {\mathbb C}}
\newcommand{A}{\mathcal A}
\newcommand{D}{\mathcal D}
\newcommand{E}{\mathcal E}
\newcommand{F}{\mathcal F}
\newcommand{G}{\mathcal G}
\newcommand{H}{\mathcal H}
\newcommand{J}{\mathcal J}
\newcommand{L}{\mathcal L}
\newcommand{U}{\mathcal U}
\newcommand{M}{\mathcal M}
\newcommand{O}{\mathcal O}
\newcommand{P}{\mathcal P}
\newcommand{S}{\mathcal S}
\newcommand{T}{\mathcal T}
\newcommand{V}{\mathcal V}
\newcommand{W}{\mathcal W}
\newcommand{X}{\mathcal X}
\newcommand{Y}{\mathcal Y}
\newcommand{bE}{\symbf E}
\newcommand{bF}{\symbf F}
\newcommand{bD}{\symbf D}
\newcommand{bI}{\symbf I}
\newcommand{bX}{\symbf X}
\newcommand{bY}{\symbf Y}
\newcommand{nz}{\mathcal Z}
\newcommand{bT}{\mathbb T}
\newcommand{bB}{\mathbb B}
\newcommand{bS}{\mathbb S}
\newcommand{bA}{\mathbb A}
\newcommand{bL}{\mathbb L}
\newcommand{bP}{\symbf P}
\newcommand{bM}{\symbf M}
\newcommand{bH}{\mathbb H}
\newcommand{dd}{\mathrm d}
\newcommand{Mu}{\mathup M}
\newcommand{Tau}{\mathup T}
\newcommand{ae}{\operatorname{a.e.}}
\newcommand{aut}{\operatorname{aut}}
\newcommand{adj}{\operatorname{adj}}
\newcommand{char}{\operatorname{char}}
\newcommand{cov}{\operatorname{Cov}}
\newcommand{cl}{\operatorname{cl}}
\newcommand{cont}{\operatorname{cont}}
\newcommand{e}{\mathbb E}
\newcommand{pp}{\operatorname{primitive}}
\newcommand{dist}{\operatorname{dist}}
\newcommand{diam}{\operatorname{diam}}
\newcommand{fp}{\operatorname{Fp}}
\newcommand{from}{\leftarrow}
\newcommand{Gal}{\operatorname{Gal}}
\newcommand{GCD}{\operatorname{GCD}}
\newcommand{LCM}{\operatorname{LCM}}
\newcommand{fg}{\mathrm{fg}}
\newcommand{gf}{\mathrm{gf}}
\newcommand{im}{\operatorname{Im}}
\newcommand{image}{\operatorname{image}}
\newcommand{inj}{\hookrightarrow}
\newcommand{irr}{\operatorname{irr}}
\newcommand{lcm}{\operatorname{lcm}}
\newcommand{ltrieq}{\mathrel{\unlhd}}
\newcommand{ltri}{\mathrel{\lhd}}
\newcommand{loc}{ {\operatorname{loc}}}
\newcommand{null}{\operatorname{null}}
\newcommand{part}{\partial}
\newcommand{pf}{\operatorname{Pf}}
\newcommand{pv}{\operatorname{Pv}}
\newcommand{rank}{\operatorname{rank}}
\newcommand{range}{\operatorname{range}}
\newcommand{re}{\operatorname{Re}}
\newcommand{span}{\operatorname{span}}
\newcommand{su}{\operatorname{supp}}
\newcommand{sgn}{\operatorname{sgn}}
\newcommand{syn}{\operatorname{syn}}
\newcommand{var}{\operatorname{Var}}
\newcommand{res}{\operatorname{Res}}
\newcommand{data}{\operatorname{data}}
\newcommand{erfc}{\operatorname{erfc}}
\newcommand{erfcx}{\operatorname{erfcx}}
\newcommand{tr}{\operatorname{tr}}
\newcommand{col}{\operatorname{Col}}
\newcommand{row}{\operatorname{Row}}
\newcommand{sol}{\operatorname{Sol}}
\newcommand{lub}{\operatorname{lub}}
\newcommand{glb}{\operatorname{glb}}
\newcommand{lr}{\leftrightarrow}
\newcommand{phat}{^\widehat{\,\,\,}}
\newcommand{what}{\widehat}
\newcommand{wbar}{\overline}
\newcommand{wtilde}{\widetilde}
\newcommand{iid}{\operatorname{i.i.d.}}
\newcommand{Exp}{\operatorname{Exp}}
\newcommand{abs}[1]{\left| {#1}\right|}
\newcommand{d}[2]{D_{\text{KL}}\left (#1\middle\| #2\right)}
\newcommand{n}[1]{\|#1\|}
\newcommand{norm}[1]{\left\|{#1}\right\|}
\newcommand{pd}[2]{\left \langle {#1},{#2} \right \rangle}
\newcommand{argmax}[1]{\underset{#1}{\operatorname{argmax}}}
\newcommand{argmin}[1]{\underset{#1}{\operatorname{argmin}}}
\newcommand{p}[1]{\left({#1}\right)}
\newcommand{c}[1]{\left \{ {#1}\right\}}
\newcommand{s}[1]{\left [{#1}\right]}
\newcommand{a}[1]{\left \langle{#1}\right\rangle}
\newcommand{cc}[2]{\left(\begin{array}{c} #1 \\ #2 \end{array}\right)}
\newcommand{f}{\mathfrak F}
\newcommand{fi}{\mathfrak F^{-1}}
\newcommand{Fi}{\mathcal F^{-1}}
\newcommand{l}{\mathfrak L}
\newcommand{li}{\mathfrak L^{-1}}
\newcommand{Li}{\mathcal L^{-1}}
\newcommand{const}{\text{const.}}
\newcommand{Int}{\operatorname{Int}}
\newcommand{Ext}{\operatorname{Ext}}
\newcommand{Bd}{\operatorname{Bd}}
\newcommand{Cl}{\operatorname{Cl}}
\newcommand{Iso}{\operatorname{Iso}}
\newcommand{Lim}{\operatorname{Lim}}
\newcommand{src}{\operatorname{src}}
\newcommand{tgt}{\operatorname{tgt}}
\newcommand{input}{\operatorname{input}}
\newcommand{output}{\operatorname{output}}
\newcommand{weight}{\operatorname{weight}}
\newcommand{paths}{\operatorname{paths}}
\newcommand{init}{\bullet}
\newcommand{fin}{\circledcirc}
\newcommand{advance}{\operatorname{advance}}
\newcommand{di}[2]{\frac{\part}{\part {#1}^{#2}}}
$$
Direct Preference Optimization
#
Direct Preference Optimization: Your Language Model is Secretly a Reward Model
https://openreview.net/forum?id=HPuSIXJaa9#
Bradley-Terry model
#
Consider agents $(A _ i) _ {i=1}^N$, where each agent has a score $\alpha _ i > 0$.
Denote by $i \succ j$ the event that $A _ i$ beats $A _ j$. The Bradley-Terry model defines the winning probability as:
$$
P(i \succ j) = \frac{\alpha _ i}{\alpha _ i + \alpha _ j}
$$
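For instance, with scores $\alpha _ 1 = 3$ and $\alpha _ 2 = 1$, agent $1$ wins with probability $3/4$. A one-line sketch in Python:
```python
def bt_win_prob(alpha_i: float, alpha_j: float) -> float:
    """Bradley-Terry probability that agent i beats agent j."""
    return alpha_i / (alpha_i + alpha_j)

print(bt_win_prob(3.0, 1.0))  # 0.75
```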
RLHF and DPO
#
In RLHF, we train a reward model $\alpha _ \theta(x, y)$ that estimates the score, where $x$ is the prompt and $y$ is the LLM's response. Suppose $\alpha _ \theta(x, y) := \exp r _ \theta(x, y)$. Then:
$$
P(y _ w \succ y _ l) = \frac{\alpha _ \theta(x, y _ w)}{\alpha _ \theta(x, y _ w) + \alpha _ \theta(x, y _ l)}
$$
When maximizing the log likelihood, this term becomes:
$$
\log P(y _ w \succ y _ l) = \log \s{\frac{1}{1 + \exp\p{r _ \theta(x, y _ l) - r _ \theta(x, y _ w)}}} = \log \sigma\p{r _ \theta(x, y _ w) - r _ \theta(x, y _ l)}
$$
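A minimal PyTorch sketch of this term used as a reward-model training loss (maximizing the log-likelihood means minimizing its negation); the batched tensors `r_w`, `r_l` stand for $r _ \theta(x, y _ w)$ and $r _ \theta(x, y _ l)$ produced by the reward model:
```python
import torch
import torch.nn.functional as F

def bt_reward_loss(r_w: torch.Tensor, r_l: torch.Tensor) -> torch.Tensor:
    """Negative Bradley-Terry log-likelihood: -log sigma(r_w - r_l), averaged over pairs."""
    return -F.logsigmoid(r_w - r_l).mean()
```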
RLHF optimizes the following objective:
$$
\begin{aligned}
\L & = E _ {X \sim \D, Y \sim \pi} \s{r(X, Y)} - \beta \d{\pi(y | x)}{\pi _ {\text{ref}}(y | x)} = E \s{r(X, Y) - \beta \log \frac{\pi(Y | X)}{\pi _ {\text{ref}}( Y | X)}}
\end{aligned}
$$
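One common way to use the second form is to fold the KL penalty into a per-sample reward, which an RL algorithm (e.g. PPO) then maximizes. A sketch, assuming the sequence-level log-probabilities under the policy and the reference model are already available (the names are placeholders):
```python
import torch

def kl_regularized_reward(r: torch.Tensor,
                          logp_pi: torch.Tensor,
                          logp_ref: torch.Tensor,
                          beta: float) -> torch.Tensor:
    """Per-sample objective r(x, y) - beta * log(pi(y|x) / pi_ref(y|x))."""
    return r - beta * (logp_pi - logp_ref)
```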
We have a closed-form solution $\pi _ * $ for $\pi$ here, since the negated objective takes the form of a KL divergence:
$$
\begin{aligned}
-\L & \propto E\s{\log \frac{\pi(Y | X)}{\pi _ {\text{ref}}( Y | X)} - \frac{1}{\beta}r(X, Y)} = E\s{\log \frac{\pi(Y | X)}{\pi _ {\text{ref}}(Y | X)} - \log \exp\c{\frac{1}{\beta} r(X, Y)}}\\
& = E _ {Y \sim \pi(Y|X)}\s{\log \frac{\pi(Y | X)}{\pi _ {\text{ref}}(Y | X) \exp\p{r(X, Y) / \beta}}} = E\s{\log \frac{\pi(Y | X)/Z(X)}{\pi _ {\text{ref}}(Y | X) \exp\p{r(X, Y) / \beta} /Z(X)}}\\
& = E\s{\log \frac{\pi(Y | X)}{\pi _ * (Y | X)} - \log Z(X)} = \d{\pi(y | x)}{\pi _ * (y | x)} + \const
\end{aligned}
$$
There is a one-to-one relationship between the optimal policy $\pi _ * $ and the reward model $r(x, y)$:
$$
\pi _ * (y | x) = \frac{\pi _ {\text{ref}}(y | x) \exp\p{r(x, y) / \beta}}{Z(x)} \implies r(x, y) = \beta\cdot \log \c{\frac{\pi _ * (y | x)}{\pi _ {\text{ref}}(y | x)} Z(x)}
$$
- Here $Z(x) = \sum _ y \pi _ {\text{ref}}(y | x) \exp\left(r(x, y) / \beta\right)$ is the partition function; it depends only on $x$, so it acts as an arbitrary per-prompt factor.
- In RLHF, we first fit a reward model $r(x, y)$, then optimize the policy $\pi(y | x)$ against it.
- But in DPO, we notice that $r(x, y)$ can be directly parameterized by $\pi _ * (y | x)$ (see the numeric check below).
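A toy numeric check of this reparameterization over a discrete set of three candidate responses (all numbers below are made up): build $\pi _ *$ from $\pi _ {\text{ref}}$ and $r$, then recover $r$ exactly.
```python
import torch

beta = 0.1
pi_ref = torch.tensor([0.5, 0.3, 0.2])   # reference policy over 3 candidate responses
r = torch.tensor([1.0, -0.5, 0.3])       # reward for each response

# Closed-form optimum: pi_*(y|x) proportional to pi_ref(y|x) * exp(r(x, y) / beta)
unnorm = pi_ref * torch.exp(r / beta)
Z = unnorm.sum()
pi_star = unnorm / Z

# Invert: beta * log(pi_*(y|x) / pi_ref(y|x) * Z) recovers r(x, y)
r_recovered = beta * torch.log(pi_star / pi_ref * Z)
print(torch.allclose(r_recovered, r))  # True
```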
As a result, when we plug this reparameterization into the Bradley-Terry log-likelihood, $Z(x)$ cancels, and we can optimize the policy directly; the policy itself serves as the implicit reward model:
$$
\L _ {\text{DPO}}(\theta) = E\s{-\log \sigma \c{\beta \log \frac{\pi _ \theta(Y _ w | X)}{\pi _ \text{ref}(Y _ w | X)} - \beta \log \frac{\pi _ \theta(Y _ l|X)}{\pi _ \text{ref}(Y _ l | X)}}}
$$
You could probably just tune a discrete generative model in this fashion, where $y _ w$ are samples from the real data and $y _ l$ are samples generated by the model, and it would perform better.
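A minimal sketch of $\L _ \text{DPO}$ in PyTorch, assuming the summed token log-probabilities of the chosen and rejected responses under the policy and under the frozen reference model are precomputed (the argument names are my own):
```python
import torch
import torch.nn.functional as F

def dpo_loss(logp_w: torch.Tensor,      # log pi_theta(y_w | x), summed over tokens
             logp_l: torch.Tensor,      # log pi_theta(y_l | x)
             logp_ref_w: torch.Tensor,  # log pi_ref(y_w | x), no gradient
             logp_ref_l: torch.Tensor,  # log pi_ref(y_l | x), no gradient
             beta: float = 0.1) -> torch.Tensor:
    """-log sigma(beta * (log-ratio of chosen - log-ratio of rejected)), averaged over pairs."""
    ratio_w = logp_w - logp_ref_w   # implicit reward of y_w, up to the beta scale and log Z(x)
    ratio_l = logp_l - logp_ref_l
    return -F.logsigmoid(beta * (ratio_w - ratio_l)).mean()
```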
DPO Gradient
#
Recall that:
- $\sigma(x) = 1 / (1 + e^{-x})$ and $\sigma'(x) = \sigma(x)(1 - \sigma(x))$.
- And $f(x) = \log\sigma(x)$ gives $f'(x) = 1 - \sigma(x) = \sigma(-x)$.
The gradient of DPO loss takes the following form:
$$
\begin{aligned}
\nabla _ \theta \L _ \text{DPO} (\theta) & = -E \s{\sigma\p{r _ \theta(X, Y _ l) - r _ \theta(X, Y _ w)} \cdot \beta \p{\nabla _ \theta \log \pi _ \theta(Y _ w | X) - \nabla _ \theta \log \pi _ \theta(Y _ l | X)}}
\end{aligned}
$$
Here $r _ \theta(x, y) := \beta \log \frac{\pi _ \theta(y | x)}{\pi _ \text{ref}(y | x)}$ is the induced (implicit) reward model; the sigmoid factor weights each pair by how strongly the implicit reward currently misranks it.
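A quick autograd check of the key fact $\frac{\mathrm d}{\mathrm d u}\left[-\log\sigma(u)\right] = -\sigma(-u)$ on a single synthetic pair of implicit rewards (the numbers are placeholders); combined with $\nabla _ \theta r _ \theta(x, y) = \beta \nabla _ \theta \log \pi _ \theta(y | x)$ via the chain rule, this yields exactly the gradient above.
```python
import torch
import torch.nn.functional as F

# Implicit rewards r_theta(x, y_w), r_theta(x, y_l) for one synthetic pair.
r_w = torch.tensor(1.3, requires_grad=True)
r_l = torch.tensor(0.4, requires_grad=True)

loss = -F.logsigmoid(r_w - r_l)   # per-pair DPO loss as a function of the implicit rewards
loss.backward()

# d loss / d r_w = -sigma(r_l - r_w);  d loss / d r_l = +sigma(r_l - r_w)
weight = torch.sigmoid(r_l - r_w)
print(torch.allclose(r_w.grad, -weight), torch.allclose(r_l.grad, weight))  # True True
```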