Skip to content

Speculative Sampling for Faster LLM Inference

Posted on:June 20, 2024 at 12:00 AM

背景介绍

在 LLM 推理过程中,主要采用了 AutoRegressive Sampling(ArS) 的方式来执行的,也就是说没有办法通过一次 forward pass 就可以获得最终的推理结果,这是和其他模型推理(比如图像检测,分类等)相比最大的区别。

下面是 ArS 推理过程中的算法描述

对于 target model 来说,在初始的 prompt 推理结束之后,就进入 for 循环开始 ArS 过程,直到生成到目标长度的为止。

由于每次推理只会生成一个 token,但是需要把所有的模型权重都 load 到 SRAM 进行计算,所以这个过程是 memory bandwidth bounded。

算法介绍

考虑到 ArS 是 memory bound,而且每次推理只能生成一个 token,有没有办法一次推理生成多个 token 呢?

这样做有下面的好处:

  1. 模型的总 latency 由 (单次推理时间 x 推理次数决定) 来决定,如果每次推理产出的 token 数量更多,那么总的推理次数就会更少,这样就可以减少模型的 latency;
  2. 由于 ArS 是 memory bound,所以同时推理一个 token 和多个 token 在 kernel 内部的计算上几乎没有区别,也就是说我们可以保证单次推理时间几乎不变;

Speculative Sampling 的核心 idea 就是单步推理多个 token,同时还保证了推理的结果分布和以前一致,那么他是怎么做到的呢?

  1. 使用一个小模型(draft model)来生成 K 个 draft tokens;
  2. 使用原始的大模型(target model)给生成的 K 个 draft tokens 打分;
  3. 采用 modified rejection sampling 算法来决定 K 个 draft tokens 的接受和拒绝。

首先第一步采用 draft model 进行推理,由于 draft model 可以选择一个非常小的模型(100M),所以和 target model(100B)相比,他的推理时间可以几乎忽略,这样我们就在几乎没有 cost 的情况下拿到了 K 个 token。

第二步使用 target model 对 K 个 draft tokens 进行打分,这一步可以一次 forward 推理 K 个 tokens,实现了我们之前说的单次推理多个 tokens,由于 memory bound 的特性,所以时间和单个 token 的推理非常接近。

第三步通过一个拒绝采样算法来判断我们是否要接受 draft tokens 的结果,这一步非常关键,因为 draft model 非常小,所以其生成的结果一定存在错误,如果可以通过 target model 对结果进行修正,那么才能保证生成的结果和原始分布一直。

通过上面三个步骤就实现了我们上面的需求,具体的算法流程如下

该算法流程基本描述了上面的三个步骤,其中最关键的就是第三步,主要是通过采样 γU[0,1]\gamma \sim U[0,1] ,判断是否 γ<min(1,qp)\gamma < \min(1, \frac{q}{p}) 来决定 draft token 是否被接受,如果接受那么继续判断下一个 draft token,如果拒绝则重新采样一个新的 token。

新的 token 采样是通过 x(qp)+x \sim (q - p)_+ 来确定的,其中

(f(x))+=max(0,f(x)xmax(0,f(x))(f(x))_+ = \frac{\max(0, f(x)}{\sum_x \max(0, f(x))}

因为 q 和 p 分别是 draft token 和 target token 的概率分布,所以直接代入公式可以进行计算,接着再进行采样即可。

最后注意如果某个 draft token xnx_n 被拒绝,那么他后面的 draft token xn+1xn+kx_{n+1} \dots x_{n+k} 都会被拒绝,因为他们都是基于 xnx_n 做 ArS 生成的,所以结果都是错误的。

根据上面的结论,考虑最坏的情况,即第一个 draft token 就被拒绝了,那么通过拒绝采样算法,我们能 resample 一次结果,保证我们仍能够输出一个 token;而在最好的情况,即所有的 draft token 都被接受了,那么我们一次 forward 就获得了 K 个 token。

数学推导

下面来推导一下为什么采用投机采样是精确解,而不是近似解。

假设 PtP_tPdP_d 分别表示 target model 和 draft model 的概率分布, x~Pd\tilde{x} \sim P_d 表示 draft model 的下一次采样, XX 表示采用投机采样之后,实际的采样结果。

如果 X=xiX=x_i ,那么只有两种情况:如果 x~=xi\tilde{x} = x_i,然后接受这个情况;或者是拒绝这次 draft model 的结果,然后 resample 新的结果。

可以有下面的公式

P(X=xi)=P(X=xix~=xi)P(x~=xi)+jP(X=xix~=xj)P(x~=xj)P(X=x_i) = P(X=x_i | \tilde{x} = x_i)P(\tilde{x}=x_i) + \sum_j P(X=x_i | \tilde{x} = x_j)P(\tilde{x} = x_j)

对于第一项 P(X=xix~=xi)P(x~=xi)P(X=x_i | \tilde{x} = x_i)P(\tilde{x}=x_i) 表示接受的情况,而根据前面的描述,设定 accept rate 是 P(X=xix~=xi)=min(1,Pt(xi)Pd(xi))P(X=x_i | \tilde{x} = x_i)=\min(1, \frac{P_t(x_i)}{P_d(x_i)})P(x~=xi)=Pd(xi)P(\tilde{x}=x_i) = P_d(x_i) ,所以代入公式得到

P(X=xix~=xi)P(x~=xi)=min(1,Pt(xi)Pd(xi))Pd(xi)=min(Pd(xi),Pt(xi))P(X=x_i | \tilde{x} = x_i)P(\tilde{x}=x_i) = \min(1, \frac{P_t(x_i)}{P_d(x_i)}) * P_d(x_i) = \min(P_d(x_i), P_t(x_i))

对于第二项, P(X=xix~=xj)P(X=x_i | \tilde{x} = x_j) 表示 draft model 选 xjx_j 但是最终采样选 xix_i 的情况,根据论文中的设定,拒绝的概率是 1-min, resampling xix_i 概率是 (Pt(xi)Pd(xi))+(P_t(x_i)-P_d(x_i))_+ ,其中

(f(x))+=max(0,f(x)xmax(0,f(x))(f(x))_+ = \frac{\max(0, f(x)}{\sum_x \max(0, f(x))}

那么 P(X=xix~=xj)=(1min(1,Pt(xj)Pd(xj))(Pt(xi)Pd(xi))+P(X=x_i | \tilde{x} = x_j)=(1-\min(1, \frac{P_t(x_j)}{P_d(x_j)}) (P_t(x_i)-P_d(x_i))_+

所以最后有

jP(X=xix~=xj)P(x~=xj)=j(1min(1,Pt(xj)Pd(xj)))max(0,Pt(xi)Pd(xi))kmax(0,Pt(xk)Pd(xk))Pd(xj)=j(Pd(xj)min(Pd(xj),Pt(xj))max(0,Pt(xi)Pd(xi))kmax(0,Pt(xk)Pd(xk))=max(0,Pt(xi)Pd(xi))j(Pd(xj)min(Pd(xj),Pt(xj))kmax(0,Pt(xk)Pd(xk))\begin{align} \sum_j P(X=x_i | \tilde{x} = x_j)P(\tilde{x} = x_j) =\sum_j(1-\min(1, \frac{P_t(x_j)}{P_d(x_j)})) \frac{\max(0, P_t(x_i) - P_d(x_i))}{\sum_k \max(0, P_t(x_k) - P_d(x_k))} P_d(x_j) \\ =\sum_j (P_d(x_j) - \min(P_d(x_j), P_t(x_j)) \frac{\max(0, P_t(x_i) - P_d(x_i))}{\sum_k \max(0, P_t(x_k) - P_d(x_k))} \\ =\max(0, P_t(x_i) - P_d(x_i)) \frac{\sum_j (P_d(x_j) - \min(P_d(x_j), P_t(x_j))}{\sum_k \max(0, P_t(x_k) - P_d(x_k))} \end{align}

而我们知道下面的公式成立

min(a,b)+max(0,ba)=b\min(a, b) + \max(0, b-a) = b

所以代入 PtP_tPdP_d 可以得到

max(0,PdPt)=Pdmin(Pt,Pd)\max(0, P_d - P_t) = P_d - \min(P_t, P_d)

所以可以得到

j(Pd(xj)min(Pd(xj),Pt(xj))=jmax(0,Pd(xj)Pt(xj))\sum_j (P_d(x_j) - \min(P_d(x_j), P_t(x_j))=\sum_j \max(0, P_d(x_j) - P_t(x_j))

所以最终可以得到

j(Pd(xj)min(Pd(xj),Pt(xj))kmax(0,Pt(xk)Pd(xk))=jmax(0,Pd(xj)Pt(xj))kmax(0,Pt(xk)Pd(xk))=1\frac{\sum_j (P_d(x_j) - \min(P_d(x_j), P_t(x_j))}{\sum_k \max(0, P_t(x_k) - P_d(x_k))} =\frac{\sum_j \max(0, P_d(x_j) - P_t(x_j))}{\sum_k \max(0, P_t(x_k) - P_d(x_k))} \\ =1

注意对于最后一个等号,可以这样推导

iPd(xi)=iPt(xi)=1\sum_i P_d(x_i) = \sum_i P_t(x_i) = 1

设定 I={iPd(xi)Pt(xi)}I = \{ i | P_d(x_i) \leq P_t(x_i)\} ,那么可以有

iI(Pt(xi)Pd(xi))=iI(Pd(xi)Pt(xi))=imax(0,Pt(xi)Pd(xi))=imax(0,Pd(xi)Pt(xi))\begin{align} \sum_{i \in I} (P_t(x_i) - P_d(x_i)) &= \sum_{ i \notin I}(P_d(x_i) - P_t(x_i)) \\ &= \sum_i \max(0, P_t(x_i) - P_d(x_i)) \\ &= \sum_i \max(0, P_d(x_i) - P_t(x_i)) \end{align}

综合上面所有的结论,可以得到

P(X=xi)=min(Pd(xi),Pt(xi))+max(0,Pt(xi)Pd(xi))=Pt(xi)P(X=x_i) = \min(P_d(x_i), P_t(x_i)) + \max(0, P_t(x_i) - P_d(x_i)) = P_t(x_i)

Reference


Comments