Transformers源码解析:transformers/src/transformers/models/llama/modeling_llama.py RotaryEmbedding
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
名为LlamaRotaryEmbedding的PyTorch模型,用于旋转位置编码。我们将逐行解释该代码:
-
class LlamaRotaryEmbedding(nn.Module):定义一个名为LlamaRotaryEmbedding的类,该类继承自nn.Module。 -
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):- 初始化函数,设置默认的最大位置编码数为2048,基数
base为10000。
- 初始化函数,设置默认的最大位置编码数为2048,基数
-
super().__init__():调用父类nn.Module的初始化函数。 -
self.dim = dim: 存储传入的dim到类的属性中。 -
self.max_position_embeddings = max_position_embeddings: 存储最大的位置编码数。 -
self.base = base: 存储基数。 -
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)):- 计算逆频率
inv_freq,torch.arange(0, self.dim, 2)生成从0到dim-1(不包括)的数字,步长为2。 - 这个数组表示不同的频率,将它除以
dim后再和base的逆次幂做运算,得到逆频率。
- 计算逆频率
-
self.register_buffer("inv_freq", inv_freq, persistent=False): 注册一个缓冲区,用于存储inv_freq,并确保它不会在保存模型时被视为模型的可训练参数。 -
后面几行代码是预先计算并缓存cosine和sine值,以加速前向计算。
-
def _set_cos_sin_cache(self, seq_len, device, dtype):: 定义一个内部方法用于设置cosine和sine的缓存。 -
self.max_seq_len_cached = seq_len: 存储传入的序列长度。 -
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype): 生成一个从0到self.max_seq_len_cached-1的数字数组。 -
freqs = torch.einsum("i,j->ij", t, self.inv_freq): 使用外积计算freqs。 -
emb = torch.cat((freqs, freqs), dim=-1): 将freqs与自身进行拼接。 -
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False): 计算cosine值并将其缓存。 -
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False): 计算sine值并将其缓存。 -
def forward(self, x, seq_len=None):: 定义前向传播函数。 -
if seq_len > self.max_seq_len_cached:: 如果输入的序列长度大于缓存的长度,则更新缓存。 -
return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ): 返回对应长度的cosine和sine缓存值。
总体来说,这是一个用于生成旋转位置编码的模块。其目的是为transformer模型(如BERT、GPT等)生成位置编码。
更多推荐


所有评论(0)