Positional encoding (PE) is essential for transformers, which underpin most large language models (LLMs) today. There are many types of PE, and Rotary Positional Embedding (RoPE) is one of them.

RoPE was first proposed in RoFormer [1] and has since been adopted by many popular transformer models, such as LLaMA [2], open-sourced by Meta. This post introduces the basic concept of RoPE, its derivation, and the length extrapolation techniques related to it.

Background

Before we start, let’s review some basic concepts of transformer and positional encoding.

Transformer 101

Given a sequence $S_N$ of length $N$, let $t_i$ be the $i$-th token in the sequence and $x_i$ the $d$-dimensional embedding of $t_i$. We can write $S_N$ and the embedding sequence $E_N$ as:

$$ S_N = \{t_1, t_2, \dots, t_N\} \\ E_N = \{x_1, x_2, \dots, x_N\} $$

Before computing self-attention, we need to transform each $x_i \in E_N$ into $Q_i$, $K_i$, $V_i$ with a linear projection that also injects positional information. We can formulate this as:

$$ Q_i = f_q(x_i, p_i) \\ K_j = f_k(x_j, p_j) \\ V_j = f_v(x_j, p_j) $$

where $p_i$ and $p_j$ are the positional information of $x_i$ and $x_j$ respectively, and $f_q$, $f_k$, $f_v$ are the transform functions.
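As one concrete instance (an illustration only, with $W_q$, $W_k$, $W_v$ assumed to be learned projection matrices that are not named in the text above), an additive absolute scheme adds the position vector to the embedding before projecting:

$$ f_q(x_i, p_i) = W_q (x_i + p_i) \\ f_k(x_j, p_j) = W_k (x_j + p_j) \\ f_v(x_j, p_j) = W_v (x_j + p_j) $$

Other schemes keep the same interface but inject $p_i$ differently, which is why the general form $f(x, p)$ is used here.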

Then, with $d_k$ denoting the dimension of the query and key vectors, we can compute the self-attention scores and the output as:

$$ \text{Attention}(Q_i, K_j) = \frac{\exp(\frac{Q_i K_j^T}{\sqrt{d_k}})}{\sum_{l=1}^{N} \exp(\frac{Q_i K_l^T}{\sqrt{d_k}})} \\ \text{Output}(Q_i) = \sum_{j=1}^{N} \text{Attention}(Q_i, K_j) V_j $$
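The two formulas above can be sketched for a single query in plain Python. This is a minimal illustration, not an efficient implementation: the function name `attention_output` and the toy vectors are made up for the example, and the max is subtracted before exponentiating purely for numerical stability (it does not change the result).

```python
import math

def attention_output(q, keys, values):
    """Softmax attention for a single query vector.

    q: list of d_k floats; keys: N lists of d_k floats; values: N lists of floats.
    """
    d_k = len(q)
    # Scaled dot-product logits Q_i K_j^T / sqrt(d_k), one per key
    logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
    # Subtract the max before exponentiating for numerical stability
    peak = max(logits)
    exps = [math.exp(l - peak) for l in logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Output is the attention-weighted sum of the value vectors
    dim_v = len(values[0])
    output = [sum(w * v[c] for w, v in zip(weights, values)) for c in range(dim_v)]
    return weights, output

# Toy example (hypothetical values): the first key is most similar to the query
weights, output = attention_output(
    q=[1.0, 0.0],
    keys=[[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]],
    values=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
)
```

Note that nothing in this computation depends on the order of `keys` and `values`, which is why positional information has to be injected separately.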

Positional Encoding in Transformers

Self-attention processes all tokens in parallel and has no built-in notion of order: shuffling the input tokens simply shuffles the outputs. Positional encoding comes to the rescue. Generally, there are two major approaches:

  1. Fuse positional information into the input embedding, which is called absolute positional embedding;
  2. Fuse positional information into the self-attention scores, which is called relative positional embedding.

Absolute Positional Embedding (APE). The most common approach, proposed in the original Transformer paper [3], adds a fixed positional embedding to the input embedding. Periodic functions (sine and cosine) are used to generate the positional embedding. The formula is:

$$ PE_{(pos, 2i)} = \sin(pos / 10000^{\frac{2i}{d}}) \\ PE_{(pos, 2i+1)} = \cos(pos / 10000^{\frac{2i}{d}}) $$

where $pos$ is the position of the token, $d$ is the dimension of the embedding, and $i \in \{0, \dots, d/2 - 1\}$ indexes the sine/cosine pair that occupies dimensions $2i$ and $2i+1$.

The Python code is as follows:

import numpy as np

# seq_len: length of sequence
# d: dimension of embedding

def get_pos_embedding(seq_len, d):
    # pos: positions of all tokens, 0 .. seq_len - 1
    pos = np.arange(seq_len)
    pos_embedding = np.zeros((seq_len, d))
    for i in range(d):
        if i % 2 == 0:
            # even index using sine
            pos_embedding[:, i] = np.sin(pos / 10000 ** (i / d))
        else:
            # odd index using cosine, with the same frequency as index i - 1
            pos_embedding[:, i] = np.cos(pos / 10000 ** ((i - 1) / d))
    return pos_embedding

The sine/cosine positional encoding is periodic, so it can be expected to extrapolate to some degree beyond the sequence lengths seen in training [4].

Another common choice is a learned version of APE, where the positional embedding is a trainable parameter, as in GPT-3 [5].

Relative Positional Embedding (RPE). Relative position encoding doesn’t model the position of each individual token. Instead, it models the relative position between tokens when computing self-attention scores.

For example, T5’s relative bias [6] first maps the relative distance $(i-j)$ between the tokens at positions $i$ and $j$ to a scalar bias value $b = f(i-j)$, which is then added to the dot product of query and key in the self-attention mechanism.
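A minimal sketch of the idea, under loud simplifications: T5 learns its biases and groups distances into buckets, whereas this example uses a fixed, hand-written table with simple clipping. The names `attention_with_relative_bias`, `bias_table`, and `max_distance` are hypothetical and exist only for this illustration.

```python
import math

def softmax(xs):
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention_with_relative_bias(logits, bias_table, max_distance):
    """Add b = f(i - j) to each query-key logit, then apply a row-wise softmax.

    logits: N x N list of query-key dot products.
    bias_table: hypothetical dict mapping a clipped distance to a scalar bias.
    """
    n = len(logits)
    weights = []
    for i in range(n):
        row = []
        for j in range(n):
            # Clip the relative distance so it indexes a finite table
            rel = max(-max_distance, min(max_distance, i - j))
            row.append(logits[i][j] + bias_table[rel])
        weights.append(softmax(row))
    return weights

# Toy bias that penalises distant tokens: f(i - j) = -|i - j| (an assumption, not T5's)
max_distance = 2
bias_table = {r: -float(abs(r)) for r in range(-max_distance, max_distance + 1)}
uniform_logits = [[0.0] * 4 for _ in range(4)]
weights = attention_with_relative_bias(uniform_logits, bias_table, max_distance)
```

With uniform logits, the only thing that separates the attention weights is the bias, so nearer tokens receive more weight and all tokens beyond `max_distance` are treated alike.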

Rotary Positional Embedding(RoPE)

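APE is straightforward and easy to implement, while RPE is more intuitive and effective. RoPE combines the advantages of both: it is applied to each query and key separately, like an absolute encoding, yet the resulting attention score depends only on the relative position.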

Formulation

Given $q, k$, we can add absolute positional information as follows:

$$ \widehat{q}_m = f(q, m) \\ \widehat{k}_n = f(k, n) $$

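$f(\cdot, m)$ is the function that adds the positional information of position $m$ to its input. Because attention is computed with a dot product, we want the dot product of the encoded query and key to depend only on the relative position $m-n$:

$$ \langle f(q, m), f(k, n) \rangle = g(q, k, m-n) $$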

where $g(\cdot, \cdot, \cdot)$ is the function to compute self-attention scores. We can assume $f(q, 0) = q$ and $f(k, 0) = k$ safely.

Consider the 2-d case and identify each vector with a complex number, $q = (q_1, q_2) = q_1 + i * q_2$ (and likewise for $k$). Then we can get:

$$ \begin{align} \langle q, k \rangle &= q_1 * k_1 + q_2 * k_2 \\ q*\bar{k} &= (q_1 + i * q_2) * (k_1 - i * k_2) \\ &= q_1 * k_1 + q_2 * k_2 + i * (q_2 * k_1 - q_1 * k_2) \\ \langle q, k \rangle &= Re[q*\bar{k}] \end{align} $$

where $\bar{k} = k_1 - i * k_2$ is the conjugate of $k$ and $Re[\cdot]$ is the real part of a complex number.
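The identity between the inner product and the real part of $q * \bar{k}$ is easy to check numerically with Python's built-in complex type. The values below are arbitrary and chosen only for illustration.

```python
q1, q2 = 3.0, -1.0
k1, k2 = 2.0, 5.0

# Treat each 2-d vector as a complex number
q = complex(q1, q2)
k = complex(k1, k2)

dot = q1 * k1 + q2 * k2       # ordinary inner product <q, k>
product = q * k.conjugate()   # q times the conjugate of k

# The real part of the complex product equals the inner product
same = product.real == dot
```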

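Then we can get:

$$ \langle f(q, m), f(k, n) \rangle = Re[f(q, m) * \bar{f}(k, n)] = g(q, k, m-n) $$

To simplify the derivation, we impose the stronger condition $f(q, m) * \bar{f}(k, n) = g(q, k, m-n)$, where $g$ is now complex-valued and the attention score is its real part.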

Using the polar form of complex numbers, with $\|\cdot\|$ for the modulus and $\theta_{(\cdot)}$ for the argument, we can get:

$$ \begin{align} f(q, m) &= \|f(q, m)\| * e^{i * \theta_{f(q, m)}} \\ f(k, n) &= \|f(k, n)\| * e^{i * \theta_{f(k, n)}} \\ f(q, m) * \bar{f}(k, n) &= \|f(q, m)\| * \|f(k, n)\| * e^{i * (\theta_{f(q, m)} - \theta_{f(k, n)})} \\ &= \|g(q, k, m-n)\| * e^{i * \theta_{g(q, k, m-n)}} \end{align} $$

Matching the moduli and the arguments in Eq.(7) and Eq.(8) gives:

$$ \begin{align} \|g(q, k, m-n)\| &= \|f(q, m)\| * \|f(k, n)\| \\ \theta_{g(q, k, m-n)} &= \theta_{f(q, m)} - \theta_{f(k, n)} \end{align} $$

To solve Eq.(9), set $n = m$. The left-hand side then no longer depends on $m$, so the right-hand side can be evaluated at $m = 0$, where $f(q, 0) = q$ and $f(k, 0) = k$:

$$ \begin{align} \|g(q, k, 0)\| &= \|f(q, m)\| * \|f(k, m)\| \\ &= \|f(q, 0)\| * \|f(k, 0)\| \\ &= \|q\| * \|k\| \\ \|f(q, m)\| * \|f(k, m)\| &= \|q\| * \|k\| \\ \|f(q, m)\|^2 &= \|q\|^2 \quad \text{(taking } k = q \text{)} \end{align} $$

Eq.(15) shows that $\|f(q, m)\|$ does not depend on $m$: $\|f(q, m)\| = \|q\|$ and, by the same argument, $\|f(k, n)\| = \|k\|$.

For $\theta_{g(q, k, m-n)}$, set $n = m$ in Eq.(10):

$$ \theta_{g(q, k, 0)} = \theta_{f(q, m)} - \theta_{f(k, m)} $$

The left-hand side does not depend on $m$, so comparing with the case $m = 0$ gives:

$$ \theta_{f(q, 0)} - \theta_{f(k, 0)} = \theta_q - \theta_k = \theta_{g(q, k, 0)} = \theta_{f(q, m)} - \theta_{f(k, m)} \\ \Downarrow \\ \theta_{f(q, m)} - \theta_q = \theta_{f(k, m)} - \theta_k $$

where $\theta_q$ and $\theta_k$ are the angles of $q$ and $k$ respectively.

This shows that $\theta_{f(q, m)} - \theta_q$ depends only on $m$, because replacing $q$ with $k$ leaves its value unchanged. Writing $\theta_{f(q, m)} - \theta_q = \Phi(m)$, we can get:

$$ \begin{align} \Phi(m) - \Phi(m-1) &= (\theta_{f(q, m)} - \theta_q) - (\theta_{f(k, m-1)} - \theta_k) \\ &= \theta_{f(q, m)} - \theta_{f(k, m-1)} + \theta_k - \theta_q \\ &= \theta_{g(q, k, 1)} + \theta_k - \theta_q \end{align} $$

The last step applies Eq.(10) with $n = m-1$. The right-hand side does not depend on $m$, so $\Phi(m)$ is an arithmetic progression.

Since $f(q, 0) = q$ gives $\Phi(0) = 0$, we have $\Phi(m) = m * \theta$ for some constant step $\theta$, i.e. $\theta_{f(q, m)} = m * \theta + \theta_q$. Substituting into the polar form, we can get:

$$ \begin{align} f(q, m) &= \|f(q, m)\| * e^{i * \theta_{f(q, m)}} \\ &= \|q\| * e^{i * (m*\theta + \theta_q)} \\ &= \|q\| * e^{i * \theta_q} * e^{i * m\theta} \\ &= q * e^{i * m \theta} \end{align} $$

In two dimensions, multiplying $q$ by $e^{i * m \theta}$ can be written in the following matrix form:

$$ q * e^{i * m \theta} = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix} * \begin{pmatrix} q_1 \\ q_2 \end{pmatrix} $$

As stated at the beginning, $f(\cdot, m)$ is the function that adds positional information, and here it rotates $q$ by the angle $m\theta$. In other words, $f(\cdot, m)$ is a rotation.
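The 2-d result can be checked numerically with the standard `cmath` module. This is a small sketch with arbitrary illustrative values for `q`, `k`, and `theta`; the helper names `rope_2d` and `score` are made up for the example.

```python
import cmath

def rope_2d(x: complex, m: int, theta: float) -> complex:
    """Rotate the 2-d vector x (as a complex number) by the angle m * theta."""
    return x * cmath.exp(1j * m * theta)

def score(q: complex, k: complex, m: int, n: int, theta: float) -> float:
    """Inner product of the rotated query and key: Re[f(q, m) * conj(f(k, n))]."""
    return (rope_2d(q, m, theta) * rope_2d(k, n, theta).conjugate()).real

# Arbitrary illustrative values
q, k, theta = complex(0.3, -1.2), complex(0.8, 0.5), 0.1

# Same offset m - n = 3 at very different absolute positions
s_near = score(q, k, 5, 2, theta)
s_far = score(q, k, 105, 102, theta)
# A different offset (m - n = 1) changes the score
s_other = score(q, k, 5, 4, theta)
```

Here `s_near` and `s_far` agree up to floating-point error, while `s_other` differs: the score depends on the positions only through $m - n$, and the rotation leaves the length of each vector unchanged.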

General Forms

For a higher-dimensional space, we can split the $d$ dimensions into $d/2$ pairs and apply an independent 2-d rotation to each pair. Putting all these pieces together, we can get the general form of $f(q, m)$:

$$ f(q, m) = \begin{pmatrix} M_0 \\ & M_1 \\ & & \ddots \\ & & & M_{d/2-1} \end{pmatrix} * \begin{pmatrix} q_0 \\ q_1 \\ \vdots \\ q_{d-1} \end{pmatrix} $$

where $M_i = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix}$ is the 2-d rotation matrix acting on the $i$-th pair $(q_{2i}, q_{2i+1})$, and $\theta_i$ is the rotation angle per position for that pair.

Because the block-diagonal matrix is sparse, we can compute the same result with element-wise operations instead:

$$ f(q, m) = \begin{pmatrix} q_0 \\ q_1 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix} \odot \begin{pmatrix} \cos(m\theta_0) \\ \cos(m\theta_0) \\ \vdots \\ \cos(m\theta_{d/2-1}) \\ \cos(m\theta_{d/2-1}) \end{pmatrix} + \begin{pmatrix} -q_1 \\ q_0 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{pmatrix} \odot \begin{pmatrix} \sin(m\theta_0) \\ \sin(m\theta_0) \\ \vdots \\ \sin(m\theta_{d/2-1}) \\ \sin(m\theta_{d/2-1}) \end{pmatrix} $$

where $\odot$ is the element-wise multiplication. We can follow the sine/cosine position encoding and set $\theta_i = 10000^{-2i/d}$ for $i = 0, \dots, d/2 - 1$.
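As a sanity check, here is a sketch in plain Python that applies one 2-d rotation per pair and compares it with an element-wise computation (each element times a cosine, plus its pair partner times a sine). The function names and test vectors are hypothetical and exist only for this illustration; a real implementation would use tensors.

```python
import math

def rope_frequencies(d: int, base: float = 10000.0):
    """theta_i = base^(-2i/d) for i = 0, ..., d/2 - 1."""
    return [base ** (-2 * i / d) for i in range(d // 2)]

def rope_blockwise(q, m):
    """Apply one 2-d rotation matrix per consecutive pair (q_{2i}, q_{2i+1})."""
    out = []
    for i, theta in enumerate(rope_frequencies(len(q))):
        c, s = math.cos(m * theta), math.sin(m * theta)
        a, b = q[2 * i], q[2 * i + 1]
        out += [c * a - s * b, s * a + c * b]
    return out

def rope_elementwise(q, m):
    """Same result as q * cos + swap_pairs(q) * sin, with no matrix at all."""
    thetas = rope_frequencies(len(q))
    cos = [math.cos(m * thetas[j // 2]) for j in range(len(q))]
    sin = [math.sin(m * thetas[j // 2]) for j in range(len(q))]
    # (q_0, q_1, q_2, q_3, ...) -> (-q_1, q_0, -q_3, q_2, ...)
    swapped = [-q[j + 1] if j % 2 == 0 else q[j - 1] for j in range(len(q))]
    return [q[j] * cos[j] + swapped[j] * sin[j] for j in range(len(q))]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Arbitrary 8-dimensional query and key
q = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 0.1, -0.2]
k = [1.0, 0.3, -0.6, 0.9, 0.2, -1.1, 0.7, 0.4]

block = rope_blockwise(q, 7)
elem = rope_elementwise(q, 7)
# Scores for the same offset (4) at two different absolute positions
score_a = dot(rope_elementwise(q, 10), rope_elementwise(k, 6))
score_b = dot(rope_elementwise(q, 110), rope_elementwise(k, 106))
```

The two functions agree up to floating-point error, and `score_a` matches `score_b`, so the relative-position property carries over from the 2-d case to every pair of dimensions.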

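There are two ways to implement RoPE. The first is to use complex numbers to represent the rotation, and the second is to use the real-valued element-wise form above. The code below, from the LLaMA reference implementation, takes the first approach: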

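# ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py#L64

from typing import Tuple

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # theta_i = theta^(-2i/dim) for i = 0, ..., dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # positions m = 0, ..., end - 1
    freqs = torch.outer(t, freqs).float()  # m * theta_i, shape (end, dim // 2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # e^{i * m * theta_i}, complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # helper from the same file: view freqs_cis as (1, seq_len, 1, head_dim // 2)
    ndim = x.ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # xq, xk: (batch, seq_len, n_heads, head_dim); view each consecutive pair as one complex number
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # complex multiplication rotates each pair by m * theta_i
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)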

Length Extrapolation

Linear Scaling

NTK-aware Scaling

Dynamic NTK Scaling

Reference