class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

名为LlamaRotaryEmbedding的PyTorch模型,用于旋转位置编码。我们将逐行解释该代码:

  1. class LlamaRotaryEmbedding(nn.Module): 定义一个名为LlamaRotaryEmbedding的类,该类继承自nn.Module

  2. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):

    • 初始化函数,设置默认的最大位置编码数为2048,基数base为10000。
  3. super().__init__():调用父类nn.Module的初始化函数

  4. self.dim = dim: 存储传入的dim到类的属性中。

  5. self.max_position_embeddings = max_position_embeddings: 存储最大的位置编码数。

  6. self.base = base: 存储基数。

  7. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)):

    • 计算逆频率inv_freqtorch.arange(0, self.dim, 2)生成从0到dim-1(不包括)的数字,步长为2。
    • 这个数组表示不同的频率,将它除以dim后再和base的逆次幂做运算,得到逆频率。
  8. self.register_buffer("inv_freq", inv_freq, persistent=False): 注册一个缓冲区,用于存储inv_freq,并确保它不会在保存模型时被视为模型的可训练参数。

  9. 后面几行代码是预先计算并缓存cosine和sine值,以加速前向计算。

  10. def _set_cos_sin_cache(self, seq_len, device, dtype):: 定义一个内部方法用于设置cosine和sine的缓存。

  11. self.max_seq_len_cached = seq_len: 存储传入的序列长度。

  12. t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype): 生成一个从0到self.max_seq_len_cached-1的数字数组。

  13. freqs = torch.einsum("i,j->ij", t, self.inv_freq): 使用外积计算freqs

  14. emb = torch.cat((freqs, freqs), dim=-1): 将freqs与自身进行拼接。

  15. self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False): 计算cosine值并将其缓存。

  16. self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False): 计算sine值并将其缓存。

  17. def forward(self, x, seq_len=None):: 定义前向传播函数。

  18. if seq_len > self.max_seq_len_cached:: 如果输入的序列长度大于缓存的长度,则更新缓存。

  19. return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ): 返回对应长度的cosine和sine缓存值。

总体来说,这是一个用于生成旋转位置编码的模块。其目的是为transformer模型(如BERT、GPT等)生成位置编码。

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐