RoPE and Length Scaling

Posted on: August 10, 2023 at 12:00 AM

Positional Encoding (PE) is essential for transformers, which are widely used in today's LLMs. There are many different types of PE, and RoPE is one of them.

RoPE was first proposed in RoFormer[1] and has been applied in many popular transformer models, such as LLaMA[2], open-sourced by Meta. In this blog we will introduce the basic concept of RoPE, its derivation, and the length-extrapolation techniques related to it.


Background

Before we start, let’s review some basic concepts of transformers and positional encoding.

Transformer 101

Given a sequence $S_N$ with length $N$, $t_i$ is the $i$-th token in the sequence, and $x_i$ is the $d$-dimensional embedding of $t_i$. We can formulate $S_N$ and $E_N$ as:

$$
S_N = \{t_1, t_2, \dots, t_N\} \\
E_N = \{x_1, x_2, \dots, x_N\}
$$

Before computing self-attention, we need to transform $x_i \in E_N$ into $Q_i$, $K_i$, $V_i$ with linear projections while adding extra positional information. We can formulate this as:

$$
Q_i = f_q(x_i, p_i) \\
K_j = f_k(x_j, p_j) \\
V_j = f_v(x_j, p_j)
$$

where $p_i, p_j$ are the positional information of $x_i, x_j$ respectively, and $f_q$, $f_k$, $f_v$ are the transform functions.

Then we can compute the self-attention scores as:

$$
\text{Attention}(Q_i, K_j) = \frac{\exp\left(\frac{Q_i K_j^T}{\sqrt{d_k}}\right)}{\sum_{j=1}^{N} \exp\left(\frac{Q_i K_j^T}{\sqrt{d_k}}\right)} \\
\text{Output}(Q_i) = \sum_{j=1}^{N} \text{Attention}(Q_i, K_j) V_j
$$
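
To make the notation concrete, here is a minimal NumPy sketch of the two formulas above for a single head, with random projection matrices and the positional information omitted for now:

import numpy as np

def self_attention(E, Wq, Wk, Wv):
    # E: (N, d) token embeddings; Wq/Wk/Wv: (d, d_k) projection matrices
    Q, K, V = E @ Wq, E @ Wk, E @ Wv
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # pairwise dot-product scores
    attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn /= attn.sum(axis=-1, keepdims=True)         # row-wise softmax
    return attn @ V                                  # weighted sum of values

N, d = 4, 8
rng = np.random.default_rng(0)
E = rng.normal(size=(N, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
print(self_attention(E, Wq, Wk, Wv).shape)  # (4, 8)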

Positional Encoding in Transformers

Transformers are parallel architectures, which means they cannot capture the order of tokens in a sequence on their own. Positional encoding comes to the rescue. Generally, there are two major approaches:

  1. Fuse positional information into the input embedding, which is called absolute positional embedding;
  2. Fuse positional information into the self-attention scores, which is called relative positional embedding.

Absolute Positional Embedding (APE). The most common approach was proposed in the original transformer paper[3]: add a fixed positional embedding to the input embedding. Periodic functions such as sine and cosine are used to generate the positional embedding. The formula is:

$$
PE_{(pos, 2i)} = \sin(pos / 10000^{\frac{2i}{d}}) \\
PE_{(pos, 2i+1)} = \cos(pos / 10000^{\frac{2i}{d}})
$$

where $pos$ is the position of the token, $d$ is the dimension of the embedding, and $i$ is used to compute the dimension index.

The Python code is as follows:

import numpy as np

# seq_len: length of sequence
# d: dimension of embedding

def get_pos_embedding(seq_len, d):
    pos = np.arange(seq_len)  # positions 0 .. seq_len - 1
    pos_embedding = np.zeros((seq_len, d))
    for i in range(d):
        if i % 2 == 0:
            # even index uses sine
            pos_embedding[:, i] = np.sin(pos / 10000 ** (i / d))
        else:
            # odd index uses cosine
            pos_embedding[:, i] = np.cos(pos / 10000 ** ((i - 1) / d))
    return pos_embedding

Since the sine/cosine positional encoding is periodic, it can be expected to extrapolate to some degree beyond the training length.[4]

Another common choice is a learned version of APE, where the positional embedding is a trainable parameter, as in GPT-3[5].
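
A minimal PyTorch sketch of a learned APE (the class and variable names here are illustrative, not taken from any specific model):

import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_len: int, d: int):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d)  # one trainable vector per position

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d) token embeddings
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.pos_emb(positions)       # broadcast over the batch dimension

Note that max_len is fixed when the model is trained, which is one reason a learned APE struggles to extrapolate beyond the training length.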

Relative Positional Embedding (RPE). Relative positional encoding does not model the absolute position of each token. Instead, it models the relative position between tokens when computing the self-attention scores.

For example, T5’s relative bias[6] first maps the relative distance $(i-j)$ between the tokens at positions $i$ and $j$ to a scalar bias value $b = f(i-j)$, which is then added to the dot product of query and key in the self-attention mechanism.
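
A simplified sketch of this idea (the real T5 buckets relative distances and learns one bias per head; here the bucketing is omitted and rel_bias is a hypothetical trainable table):

import torch
import torch.nn as nn

max_dist = 128
rel_bias = nn.Embedding(2 * max_dist + 1, 1)   # one scalar bias per clipped distance

def add_relative_bias(scores: torch.Tensor) -> torch.Tensor:
    # scores: (seq_len, seq_len) raw attention logits Q K^T / sqrt(d_k)
    n = scores.size(0)
    i = torch.arange(n)[:, None]
    j = torch.arange(n)[None, :]
    dist = (i - j).clamp(-max_dist, max_dist) + max_dist  # shift into [0, 2 * max_dist]
    return scores + rel_bias(dist).squeeze(-1)            # b = f(i - j) added to the logits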

Rotary Positional Embedding (RoPE)

APE is straightforward and easy to implement, while RPE is more intuitive and effective; RoPE combines the advantages of both.

Formulation

Given $q, k$, we can add absolute positional information as follows:

$$
\widehat{q}_m = f(q, m) \\
\widehat{k}_n = f(k, n)
$$

$f(\cdot, m)$ is the function that adds positional information to the inputs. Since attention is computed by dot product, the equation below needs to be satisfied:

$$
\langle f(q, m), f(k, n) \rangle = g(q, k, m-n)
$$

where $g(\cdot, \cdot, \cdot)$ is the function that computes the self-attention scores. We can safely assume $f(q, 0) = q$ and $f(k, 0) = k$.

Considering the 2-d case over the complex field, $q = (q_1, q_2) = q_1 + i q_2$, we can get:

$$
\begin{align}
\langle q, k \rangle &= q_1 k_1 + q_2 k_2 \\
q \bar{k} &= (q_1 + i q_2)(k_1 - i k_2) \\
&= q_1 k_1 + q_2 k_2 + i (q_2 k_1 - q_1 k_2) \\
\langle q, k \rangle &= \mathrm{Re}[q \bar{k}]
\end{align}
$$

where $\bar{k} = k_1 - i k_2$ is the conjugate of $k$ and $\mathrm{Re}[\cdot]$ denotes the real part of a complex number.

Then we can get:

$$
\langle f(q, m), f(k, n) \rangle = \mathrm{Re}[f(q, m) \bar{f}(k, n)] = g(q, k, m-n)
$$

For simplicity, we can assume $f(q, m) \, \bar{f}(k, n) = g(q, k, m-n)$ directly, treating $g$ as a complex-valued function.

Using the exponential form to represent complex numbers, we can get:

$$
\begin{align}
f(q, m) &= |f(q, m)| \, e^{i \theta_{f(q, m)}} \\
f(k, n) &= |f(k, n)| \, e^{i \theta_{f(k, n)}} \\
f(q, m) \, \bar{f}(k, n) &= |f(q, m)| \, |f(k, n)| \, e^{i (\theta_{f(q, m)} - \theta_{f(k, n)})} \\
&= |g(q, k, m-n)| \, e^{i \theta_{g(q, k, m-n)}}
\end{align}
$$

Equating the right-hand sides of Eq.(7) and Eq.(8), the magnitudes and phases must match separately:

$$
\begin{align}
|g(q, k, m-n)| &= |f(q, m)| \, |f(k, n)| \\
\theta_{g(q, k, m-n)} &= \theta_{f(q, m)} - \theta_{f(k, n)}
\end{align}
$$

To solve Eq.(9) and Eq.(10), start with the magnitudes: set $m = n$, and then specialize to $m = n = 0$:

$$
\begin{align}
|g(q, k, 0)| &= |f(q, m)| \, |f(k, m)| \\
&= |f(q, 0)| \, |f(k, 0)| \\
&= |q| \, |k| \\
&= \|q\| \, \|k\| \\
|f(q, m)| \, |f(k, m)| &= \|q\| \, \|k\|
\end{align}
$$

Eq.(15) holds for every $m$, so the simplest solution is to make $|f(q, m)|$ independent of $m$: we take $|f(q, m)| = \|q\|$ and $|f(k, n)| = \|k\|$.

For $\theta_{g(q, k, m-n)}$, we can set $m = n$, which gives:

$$
\theta_{g(q, k, 0)} = \theta_{f(q, m)} - \theta_{f(k, m)}
$$

If we set $m = 0$, we get:

$$
\theta_{f(q, 0)} - \theta_{f(k, 0)} = \theta_q - \theta_k = \theta_{g(q, k, 0)} = \theta_{f(q, m)} - \theta_{f(k, m)} \\
\Downarrow \\
\theta_{f(q, m)} - \theta_q = \theta_{f(k, m)} - \theta_k
$$

where $\theta_q$ and $\theta_k$ are the angles of $q$ and $k$ respectively.

This shows that $\theta_{f(q, m)} - \theta_q$ depends only on $m$, since replacing $q$ with $k$ gives the same value. If we define $\Phi(m) = \theta_{f(q, m)} - \theta_q$, we get:

$$
\begin{align}
\Phi(m) - \Phi(m-1) &= (\theta_{f(q, m)} - \theta_q) - (\theta_{f(k, m-1)} - \theta_k) \\
&= \theta_{f(q, m)} - \theta_{f(k, m-1)} + \theta_k - \theta_q \\
&= \theta_{g(q, k, 1)} + \theta_k - \theta_q
\end{align}
$$

Since the right-hand side does not depend on $m$, $\Phi(m) - \Phi(m-1)$ is a constant, i.e., $\Phi(m)$ is an arithmetic progression.

If we set $\Phi(m) = m\theta = \theta_{f(q, m)} - \theta_q$ (taking $\Phi(0) = 0$, since $f(q, 0) = q$), we get:

$$
\begin{align}
f(q, m) &= |f(q, m)| \, e^{i \theta_{f(q, m)}} \\
&= \|q\| \, e^{i (m\theta + \theta_q)} \\
&= \|q\| \, e^{i \theta_q} \, e^{i m\theta} \\
&= q \, e^{i m \theta}
\end{align}
$$

In the 2-d case, we can write $q \, e^{i m \theta}$ in the following matrix form:

$$
q \, e^{i m \theta} = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix} \begin{pmatrix} q_1 \\ q_2 \end{pmatrix}
$$

As stated at the beginning, $f(\cdot, m)$ is the function that adds position information, so it can be interpreted as rotating $q$ by the angle $m\theta$. In other words, $f(\cdot, m)$ is a rotation function.
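
A quick NumPy check that this rotary form satisfies $\langle f(q, m), f(k, n) \rangle = g(q, k, m-n)$, i.e. the dot product depends only on the relative distance $m - n$:

import numpy as np

def rotate(x, angle):
    # 2-d rotation, i.e. f(x, m) with angle = m * theta
    c, s = np.cos(angle), np.sin(angle)
    return np.array([c * x[0] - s * x[1], s * x[0] + c * x[1]])

theta = 0.1
q = np.array([1.0, 2.0])
k = np.array([0.5, -1.0])

# different (m, n) pairs with the same relative distance m - n = 3
for m, n in [(3, 0), (7, 4), (12, 9)]:
    score = rotate(q, m * theta) @ rotate(k, n * theta)
    print(m, n, round(score, 6))  # the score is identical for all three pairs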

General Forms

For higher-dimensional spaces, we can simply decompose the space into a block-wise repetition of 2-d rotations. Putting these pieces together, we get the general form of $f(q, m)$:

$$
f(q, m) = \begin{pmatrix} M_1 \\ & M_2 \\ & & \ddots \\ & & & M_{d/2} \end{pmatrix} \begin{pmatrix} q_0 \\ q_1 \\ \vdots \\ q_{d-1} \end{pmatrix}
$$

where $M_i$ is the 2-d rotation matrix $M_i = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix}$, and $\theta_i$ is the rotation frequency of the $i$-th 2-d block.

Considering the sparsity of this block-diagonal matrix, we can use the element-wise form to compute it:

$$
\begin{pmatrix} q_0 \\ q_1 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix} \odot \begin{pmatrix} \cos(m\theta_0) \\ \cos(m\theta_0) \\ \vdots \\ \cos(m\theta_{d/2-1}) \\ \cos(m\theta_{d/2-1}) \end{pmatrix} + \begin{pmatrix} -q_1 \\ q_0 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{pmatrix} \odot \begin{pmatrix} \sin(m\theta_0) \\ \sin(m\theta_0) \\ \vdots \\ \sin(m\theta_{d/2-1}) \\ \sin(m\theta_{d/2-1}) \end{pmatrix}
$$

where $\odot$ is element-wise multiplication. Following the sine/cosine positional encoding, we can set $\theta_i = 10000^{-2i/d}$.

There are two common ways to implement RoPE. The first is to use complex numbers to represent the rotation, as in the LLaMA code below; the second is to implement the element-wise real-valued form derived above (a sketch of that variant follows the LLaMA code).

# ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py#L64

import torch
from typing import Tuple

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # theta_i = 10000^(-2i/d) for each dimension pair, as derived above
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)               # positions m = 0 .. end - 1
    freqs = torch.outer(t, freqs).float()                    # angles m * theta_i, shape (end, dim // 2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)   # e^{i * m * theta_i}, complex64
    return freqs_cis

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # view adjacent dimension pairs (q_0, q_1), (q_2, q_3), ... as complex numbers
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # reshape_for_broadcast (defined in the referenced file) reshapes freqs_cis so it
    # broadcasts over the batch and head dimensions of xq_
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # complex multiplication rotates each pair by the angle m * theta_i
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
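
For the second way, here is a sketch of the element-wise (real-valued) form derived above; it pairs adjacent dimensions $(q_0, q_1), (q_2, q_3), \dots$ just like the complex view in the LLaMA code (this is a sketch, not taken from any particular library):

import torch

def apply_rotary_emb_real(x: torch.Tensor, m: torch.Tensor, theta: float = 10000.0):
    # x: (..., seq_len, d) with d even; m: (seq_len,) integer positions
    d = x.shape[-1]
    freqs = theta ** (-torch.arange(0, d, 2).float() / d)    # theta_i = 10000^(-2i/d)
    angles = m[:, None].float() * freqs[None, :]             # (seq_len, d/2) angles m * theta_i
    cos = angles.cos().repeat_interleave(2, dim=-1)          # (seq_len, d)
    sin = angles.sin().repeat_interleave(2, dim=-1)
    # build (-q_1, q_0, -q_3, q_2, ...) for the second term of the element-wise form
    x_rot = torch.stack((-x[..., 1::2], x[..., 0::2]), dim=-1).flatten(-2)
    return x * cos + x_rot * sin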

Length Extrapolation

Linear Scaling
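
When a RoPE model is evaluated beyond its training context length, the rotation angles $m\theta_i$ fall outside the range seen during training and the attention scores degrade. Linear scaling (also known as position interpolation) addresses this by dividing the position index by a scale factor, squeezing longer sequences back into the trained position range at the cost of local resolution. A minimal sketch on top of precompute_freqs_cis above, where scale is the ratio of target length to training length (the function name is illustrative, not from the LLaMA code):

import torch

def precompute_freqs_cis_linear(dim: int, end: int, scale: float, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device) / scale  # positions m -> m / scale
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)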

NTK-aware Scaling
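
NTK-aware scaling keeps the positions unchanged and instead enlarges the RoPE base, so the high-frequency dimensions (which encode local order) are barely affected while the low-frequency dimensions are stretched to cover the longer range. A sketch assuming the commonly used base adjustment theta * scale ** (dim / (dim - 2)):

import torch

def precompute_freqs_cis_ntk(dim: int, end: int, scale: float, theta: float = 10000.0):
    theta = theta * scale ** (dim / (dim - 2))  # enlarge the base instead of shrinking positions
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # positions are left untouched
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)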

Dynamic NTK Scaling
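
Dynamic NTK scaling applies the same idea but computes the scale factor on the fly from the current sequence length, so sequences within the training window are left untouched and longer ones get progressively more scaling. A sketch reusing precompute_freqs_cis_ntk above (the max(1.0, seq_len / train_len) rule is one simple choice; implementations differ in the exact formula):

def dynamic_ntk_freqs_cis(dim: int, seq_len: int, train_len: int, theta: float = 10000.0):
    scale = max(1.0, seq_len / train_len)  # no scaling within the training window
    return precompute_freqs_cis_ntk(dim, seq_len, scale, theta)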

Reference

Footnotes

  1. RoFormer: Enhanced Transformer with Rotary Position Embedding

  2. LLaMA: Open and Efficient Foundation Language Models

  3. Attention is all you need

  4. 让研究人员绞尽脑汁的 Transformer 位置编码 (The Transformer positional encodings that rack researchers’ brains)

  5. Language Models are Few-Shot Learners

  6. Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer

