背景介绍
在 LLM 推理过程中,主要采用了 AutoRegressive Sampling(ArS) 的方式来执行的,也就是说没有办法通过一次 forward pass 就可以获得最终的推理结果,这是和其他模型推理(比如图像检测,分类等)相比最大的区别。
下面是 ArS 推理过程中的算法描述
对于 target model 来说,在初始的 prompt 推理结束之后,就进入 for
循环开始 ArS 过程,直到生成到目标长度的为止。
由于每次推理只会生成一个 token,但是需要把所有的模型权重都 load 到 SRAM 进行计算,所以这个过程是 memory bandwidth bounded。
算法介绍
考虑到 ArS 是 memory bound,而且每次推理只能生成一个 token,有没有办法一次推理生成多个 token 呢?
这样做有下面的好处:
- 模型的总 latency 由 (单次推理时间 x 推理次数决定) 来决定,如果每次推理产出的 token 数量更多,那么总的推理次数就会更少,这样就可以减少模型的 latency;
- 由于 ArS 是 memory bound,所以同时推理一个 token 和多个 token 在 kernel 内部的计算上几乎没有区别,也就是说我们可以保证单次推理时间几乎不变;
Speculative Sampling 的核心 idea 就是单步推理多个 token,同时还保证了推理的结果分布和以前一致,那么他是怎么做到的呢?
- 使用一个小模型(draft model)来生成 K 个 draft tokens;
- 使用原始的大模型(target model)给生成的 K 个 draft tokens 打分;
- 采用 modified rejection sampling 算法来决定 K 个 draft tokens 的接受和拒绝。
首先第一步采用 draft model 进行推理,由于 draft model 可以选择一个非常小的模型(100M),所以和 target model(100B)相比,他的推理时间可以几乎忽略,这样我们就在几乎没有 cost 的情况下拿到了 K 个 token。
第二步使用 target model 对 K 个 draft tokens 进行打分,这一步可以一次 forward 推理 K 个 tokens,实现了我们之前说的单次推理多个 tokens,由于 memory bound 的特性,所以时间和单个 token 的推理非常接近。
第三步通过一个拒绝采样算法来判断我们是否要接受 draft tokens 的结果,这一步非常关键,因为 draft model 非常小,所以其生成的结果一定存在错误,如果可以通过 target model 对结果进行修正,那么才能保证生成的结果和原始分布一直。
通过上面三个步骤就实现了我们上面的需求,具体的算法流程如下
该算法流程基本描述了上面的三个步骤,其中最关键的就是第三步,主要是通过采样 γ∼U[0,1] ,判断是否 γ<min(1,pq) 来决定 draft token 是否被接受,如果接受那么继续判断下一个 draft token,如果拒绝则重新采样一个新的 token。
新的 token 采样是通过 x∼(q−p)+ 来确定的,其中
(f(x))+=∑xmax(0,f(x))max(0,f(x)
因为 q 和 p 分别是 draft token 和 target token 的概率分布,所以直接代入公式可以进行计算,接着再进行采样即可。
最后注意如果某个 draft token xn 被拒绝,那么他后面的 draft token xn+1…xn+k 都会被拒绝,因为他们都是基于 xn 做 ArS 生成的,所以结果都是错误的。
根据上面的结论,考虑最坏的情况,即第一个 draft token 就被拒绝了,那么通过拒绝采样算法,我们能 resample 一次结果,保证我们仍能够输出一个 token;而在最好的情况,即所有的 draft token 都被接受了,那么我们一次 forward 就获得了 K 个 token。
数学推导
下面来推导一下为什么采用投机采样是精确解,而不是近似解。
假设 Pt 和 Pd 分别表示 target model 和 draft model 的概率分布, x~∼Pd 表示 draft model 的下一次采样, X 表示采用投机采样之后,实际的采样结果。
如果 X=xi ,那么只有两种情况:如果 x~=xi,然后接受这个情况;或者是拒绝这次 draft model 的结果,然后 resample 新的结果。
可以有下面的公式
P(X=xi)=P(X=xi∣x~=xi)P(x~=xi)+∑jP(X=xi∣x~=xj)P(x~=xj)
对于第一项 P(X=xi∣x~=xi)P(x~=xi) 表示接受的情况,而根据前面的描述,设定 accept rate 是 P(X=xi∣x~=xi)=min(1,Pd(xi)Pt(xi)) , P(x~=xi)=Pd(xi) ,所以代入公式得到
P(X=xi∣x~=xi)P(x~=xi)=min(1,Pd(xi)Pt(xi))∗Pd(xi)=min(Pd(xi),Pt(xi))
对于第二项, P(X=xi∣x~=xj) 表示 draft model 选 xj 但是最终采样选 xi 的情况,根据论文中的设定,拒绝的概率是 1-min, resampling xi 概率是 (Pt(xi)−Pd(xi))+ ,其中
(f(x))+=∑xmax(0,f(x))max(0,f(x)
那么 P(X=xi∣x~=xj)=(1−min(1,Pd(xj)Pt(xj))(Pt(xi)−Pd(xi))+
所以最后有
j∑P(X=xi∣x~=xj)P(x~=xj)=j∑(1−min(1,Pd(xj)Pt(xj)))∑kmax(0,Pt(xk)−Pd(xk))max(0,Pt(xi)−Pd(xi))Pd(xj)=j∑(Pd(xj)−min(Pd(xj),Pt(xj))∑kmax(0,Pt(xk)−Pd(xk))max(0,Pt(xi)−Pd(xi))=max(0,Pt(xi)−Pd(xi))∑kmax(0,Pt(xk)−Pd(xk))∑j(Pd(xj)−min(Pd(xj),Pt(xj))
而我们知道下面的公式成立
min(a,b)+max(0,b−a)=b
所以代入 Pt 和 Pd 可以得到
max(0,Pd−Pt)=Pd−min(Pt,Pd)
所以可以得到
∑j(Pd(xj)−min(Pd(xj),Pt(xj))=∑jmax(0,Pd(xj)−Pt(xj))
所以最终可以得到
∑kmax(0,Pt(xk)−Pd(xk))∑j(Pd(xj)−min(Pd(xj),Pt(xj))=∑kmax(0,Pt(xk)−Pd(xk))∑jmax(0,Pd(xj)−Pt(xj))=1
注意对于最后一个等号,可以这样推导
∑iPd(xi)=∑iPt(xi)=1
设定 I={i∣Pd(xi)≤Pt(xi)} ,那么可以有
i∈I∑(Pt(xi)−Pd(xi))=i∈/I∑(Pd(xi)−Pt(xi))=i∑max(0,Pt(xi)−Pd(xi))=i∑max(0,Pd(xi)−Pt(xi))
综合上面所有的结论,可以得到
P(X=xi)=min(Pd(xi),Pt(xi))+max(0,Pt(xi)−Pd(xi))=Pt(xi)
Reference