RoPE旋转编码已经成为大模型的基础建设,RoPE区别于传统的绝对位置编码,为q,k向量左乘一个旋转变量,将q,k的绝对位置转换为相对位置,使不同位置的q,k进行计算后拥有相对位置信息。 绝对位置编码:(seq_len, hidden_size) theta的分子为seq_len,分母由hidden_size控制。 同一个位置的位置编码的角度越来越小。 Self-Attention计算公式: 我们希望计算Slef-Attention时,q,k拥有绝对位置,q^T k的计算可以包含相对位置信息,故我们定义一个可以实现的范式: 我们找到一个符合上述范式的函数g(): fq(xm, m)在2d状态下: fq(xm, m)在N维度状态下: fq(xm, m)在N维度状态下简便运算: 故位置i和位置j的Attention,除了本身i,j的向量信息之外,还包含了相对位置信息。

1 LLaMa实现

一般RoPE的实现都分为两步:
  • 1 生成不同seq位置不同维度的theta,(seq_len, hidden_size//2)
  • 2 将对应的x与theta按照RoPE的简便计算方法进行相乘。
LLaMa的第2步利用了虚数的计算方法。
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
import torch

def precompute_freqs(dim, end, theta):
    freqs = 1 / (theta ** torch.arange(0, dim, 2)[: dim//2].float() / dim)
    t = torch.arange(end)
    freqs = torch.outer(t, freqs)
    # seq_len, hidden/2
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def applay_rotary_emb(xq, xk, freqs_cis):
    # batch_size, seq_len, head_num, hidden/2 , 2 -> batch_size, seq_len, head_num, hidden/2
    xq = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
    xk = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
    # seq_len, hidden/2 -> 1, seq_len, 1, hidden/2
    freqs_cis = reshape_broadcast(freqs_cis)
    # batch_size, seq_len, head_num, hidden/2 -> batch_size, seq_len, head_num, hidden/2, 2 -> batch_size, seq_len, head_num, hidden
    xq_out = torch.view_as_real(xq * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk * freqs_cis).flatten(3)
    return xq_out, xk_out

2 ChatGLM3 实现

chatglm使纯粹的按照RoPE论文中的计算方式进行计算。
def forward_impl(
        self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
):
    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).float()

    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
    return cache

def forward(self, max_seq_len, offset=0):
    return self.forward_impl(
        max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
    )


def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)