transformer模型详解 pytorch的nn.Transformer模块实战

PyTorch中的nn.Transformer模块 nn.Transformer是PyTorch内置的Transformer模型实现,基于标准Transformer架构,适用于机器翻译、文本生成等序列任务,提供了编码器、解码器、注意力机制等核心组件。

一、核心模块组件

nn.Transformer包含以下关键组件:

组件类 功能说明
nn.TransformerEncoder 编码器模块:将输入序列编码为隐藏表示,由多个TransformerEncoderLayer成
nn.TransformerDecoder 解码器模块:从隐藏表示生成输出序列,由多个TransformerDecoderLayer组成
nn.TransformerEncoderLayer 编码器基本单元:包含自注意力机制+前馈网络,支持归一化、Dropout
nn.TransformerDecoderLayer 解码器基本单元:包含自注意力+编码器-解码器注意力+前馈网络 nn.MultiheadAttention 多头注意力机制:将输入分多个头独立计算注意力,再合并结果
nn.Transformer 完整Transformer模型:由编码器+解码器组成,接收输入/目标序列并输出预测

二、__call__函数的作用

PyTorch中nn.Module类实现了__call__方法,因此创建Transformer实例后,可直接通过model(input_data)调用前向传播(无需显式调用forward()):

  # model(input_data) 等价于 model.forward(input_data)
  # 前者更简洁,是PyTorch中推荐的写法
  output = model(src_seq, tgt_seq) 

三、最简单的标准Transformer模型

通过nn.Transformer()可快速创建默认配置的Transformer模型(6层编码器+6层解码器)

import torch
import torch.nn as nn# 创建默认配置的Transformer模型
model = nn.Transformer()
# 打印模型结构
print(model) 

输出的模型结构示例:

Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
        (dropout3): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)

四、纯解码器模型实现(以Llama风格为例)

很多大模型(如Llama)采用纯解码器架构
nn.TransformerDecoder的代码实战:

import torch
import torch.nn as nn

class TransformerDecoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,     # 词汇表大小
        emb_size: int,       # 词嵌入维度
        hidden_size: int,    # 隐藏层维度(前馈网络维度)
        num_layers: int,     # 解码器层数
        num_heads: int,      # 注意力头数
        dropout: float       # Dropout比例
    ):
        super().__init__()
        # 1. 词嵌入层:将token索引映射为向量
        self.embedding = nn.Embedding(vocab_size, emb_size)#传入词汇表大小和词嵌入纬度
        
        # 2. Transformer解码器:由多个解码器层组成
        self.decoder = nn.TransformerDecoder(
            # 单个解码器层配置
            nn.TransformerDecoderLayer(
                d_model=emb_size,            # 输入/输出维度(需与词嵌入维度一致)
                nhead=num_heads,             # 注意力头数
                dim_feedforward=hidden_size, # 前馈网络隐藏层维度
                dropout=dropout,             # Dropout比例
                batch_first=True             # 启用batch_first=True(与forward中输入格式一致)
            ),
            num_layers=num_layers,           # 解码器层数
            norm=nn.LayerNorm(emb_size)      # 最终归一化层(维度需与d_model一致)
        )
        
        # 3. 输出层:将解码器输出映射为词汇表概率分布
        self.fc = nn.Linear(emb_size, vocab_size)  # 输入维度是emb_size(与d_model一致)

    def forward(
        self,
        trg: torch.Tensor,          # 目标序列:[batch_size, trg_len](因batch_first=True)
        memory: torch.Tensor=None,  # 编码器输出(纯解码器时可忽略):[batch_size, src_len, emb_size]
        trg_mask: torch.Tensor=None, # 目标序列掩码:[trg_len, trg_len]
        memory_mask: torch.Tensor=None  # 编码器-解码器注意力掩码:[trg_len, src_len]
    ):
        # 步骤1:词嵌入(无需转置,因batch_first=True已在层中配置)
        trg_emb = self.embedding(trg)  # [batch_size, trg_len, emb_size]
        
        # 步骤2:解码器前向传播
        output = self.decoder(
            tgt=trg_emb,
            memory=memory,
            tgt_mask=trg_mask,
            memory_mask=memory_mask
        )  # 输出:[batch_size, trg_len, emb_size]
        
        # 步骤3:输出层映射为词汇表概率
        output = self.fc(output)  # [batch_size, trg_len, vocab_size]
        return output

# 模型实例化(以Llama 2小配置为例)
model = TransformerDecoder(
    vocab_size=32000,    # 词汇表大小
    emb_size=512,        # 词嵌入维度(与d_model一致)
    hidden_size=1024,    # 前馈网络隐藏层维度
    num_layers=6,        # 解码器层数
    num_heads=4,         # 注意力头数(需能整除emb_size:512 ÷ 4 = 128,合法)
    dropout=0.1          # Dropout比例
)

print(model)

输出出来是这样喵

TransformerDecoder(
  (embedding): Embedding(32000, 512)
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
        (dropout3): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (fc): Linear(in_features=512, out_features=32000, bias=True)

这就是解码器的实现啦喵,我做了比较详细的注释,希望你能看得懂喵

五、Llama 2风格模型简化实现(含RMSNorm、旋转位置编码)

芝士适配Llama 2的简化模型(替换并行层为标准层,适配CPU运行)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple

# -------------------------- 模型配置类 --------------------------
@dataclass
class ModelArgs:
    """
    模型超参数配置类(基于Llama 2架构简化)
    用dataclass自动生成构造函数,方便参数管理
    """
    dim: int = 4096                  # 模型核心维度(隐藏层/词嵌入维度)
    n_layers: int = 6                # Transformer块层数(原Llama 2为32层,此处简化为6层)
    n_heads: int = 32                # 注意力头数(查询Q的头数)
    n_kv_heads: Optional[int] = None # K/V注意力头数(默认与n_heads相同,可设更小以节省显存)
    vocab_size: int = 32000          # 词汇表大小(Llama 2默认32000)
    multiple_of: int = 256           # SwiGLU前馈网络隐藏层维度对齐因子(确保维度是该值的倍数)
    ffn_dim_multiplier: Optional[float] = None  # 前馈网络维度乘数(覆盖默认计算逻辑)
    norm_eps: float = 1e-5           # RMSNorm归一化层的epsilon(避免除零)
    max_batch_size: int = 32         # 最大批次大小(用于初始化K/V缓存)
    max_seq_len: int = 2048          # 最大序列长度(用于初始化K/V缓存和位置编码)

# -------------------------- RMSNorm归一化层 --------------------------
class RMSNorm(nn.Module):
    """
     Root Mean Square Normalization(均方根归一化)
    相比LayerNorm,RMSNorm只归一化,不中心化(移除均值),计算更快、更稳定
    公式:output = (x / sqrt(E[x²] + eps)) * weight
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # 数值稳定性参数
        self.weight = nn.Parameter(torch.ones(dim))  # 可学习的缩放参数(保持表达能力)

    def norm(self, x: torch.Tensor) -> torch.Tensor:
        """核心归一化计算:逐样本沿最后一维计算RMS"""
        # x.pow(2): 元素平方 -> mean(-1, keepdim=True): 最后一维求均值并保持维度
        # torch.rsqrt: 求倒数平方根 -> 等价于 1/sqrt(...)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """前向传播:归一化 + 可学习缩放"""
        # 转为float计算(避免低精度下的数值问题),再转回原数据类型
        output = self.norm(x.float()).type_as(x)
        return output * self.weight  # 缩放权重

# -------------------------- 旋转位置编码(RoPE) --------------------------
def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0
) -> torch.Tensor:
    """
    预计算旋转位置编码的复数形式(Complex Exponential)
    用于RoPE:将位置信息编码到Q/K的相位中,保持相对位置不变性
    Args:
        dim: 每个注意力头的维度(必须是偶数)
        end: 最大序列长度(预计算到该长度的编码)
        theta: 频率基数(默认10000,遵循原始Transformer位置编码)
    Returns:
        freqs_cis: [end, dim//2] 的复数Tensor(模长=1,相位随位置和维度变化)
    """
    # 计算频率:theta^(2i/dim) 的倒数,i是维度索引(0,1,...,dim//2-1)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # 序列位置 [0,1,...,end-1]
    freqs = torch.outer(t, freqs).float()       # 位置×频率:[end, dim//2]
    # 转为复数形式:polar(模长, 相位) -> 模长=1,相位=freqs
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    调整freqs_cis的形状,使其能与Q/K进行广播乘法
    Args:
        freqs_cis: 预计算的位置编码 [seq_len, dim//2]
        x: Q或K [batch_size, seq_len, n_heads, head_dim]
    Returns:
        重塑后的freqs_cis,适配x的维度
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim, "x的维度必须至少为2(batch_size, seq_len, ...)"
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), "freqs_cis形状不匹配"
    # 构建广播形状:除了seq_len(第1维)和head_dim(最后1维),其余维度设为1
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    对查询Q和键K应用旋转位置编码
    原理:将Q/K的实数值转为复数,与位置编码复数相乘(旋转相位),再转回实数
    Args:
        xq: 查询向量 [batch_size, seq_len, n_heads, head_dim]
        xk: 键向量 [batch_size, seq_len, n_kv_heads, head_dim]
        freqs_cis: 预计算的位置编码 [seq_len, head_dim//2]
    Returns:
        应用RoPE后的Q和K
    """
    # 将Q/K重塑为复数形式:[..., head_dim] -> [..., head_dim//2, 2](实部+虚部)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # 调整freqs_cis形状以适配广播
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # 复数乘法:旋转相位(模长不变,相位相加)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  # 转回实数并展平最后两维
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # 转回原数据类型
    return xq_out.type_as(xq), xk_out.type_as(xk)

# -------------------------- 前馈网络(SwiGLU) --------------------------
class FeedForward(nn.Module):
    """
    基于SwiGLU激活函数的前馈网络(Llama 2使用的FFN结构)
    结构:Linear -> SwiGLU -> Dropout -> Linear
    其中中间层维度由multiple_of和ffn_dim_multiplier共同决定
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        # 计算中间层维度:默认是dim * 4,再按multiple_of对齐
        hidden_dim = 4 * args.dim
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * args.dim)
        # 确保中间层维度是multiple_of的整数倍(硬件优化)
        hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
        
        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)  # 第一层投影
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)  # 输出投影
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)  # 门控投影(SwiGLU特有)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """前向传播:x -> w1(x) * sigmoid(w3(x)) -> w2(...)"""
        # SwiGLU激活:(w1(x) * sigmoid(w3(x))),相比ReLU更平滑、表达能力更强
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# -------------------------- 多头注意力层 --------------------------
class Attention(nn.Module):
    """
    多头注意力层(支持K/V头数少于Q头数,即分组注意力)
    包含K/V缓存机制,用于推理时加速(避免重复计算历史序列的K/V)
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        # K/V头数:默认与Q头数相同,可单独设置以减少计算量
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_local_heads = args.n_heads          # Q的注意力头数
        self.n_local_kv_heads = self.n_kv_heads    # K/V的注意力头数
        self.n_rep = self.n_local_heads // self.n_local_kv_heads  # K/V头重复次数(对齐Q头数)
        self.head_dim = args.dim // args.n_heads   # 每个注意力头的维度
        
        # Q/K/V投影层:将输入dim映射到(头数×头维度)
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        
        # 输出投影层:将多头注意力结果映射回模型dim
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        
        # K/V缓存:存储历史序列的K和V,形状[max_batch_size, max_seq_len, n_kv_heads, head_dim]
        # 用于推理时增量计算(只计算当前token的Q与历史K/V的注意力)
        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)
        )
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        前向传播:计算多头注意力
        Args:
            x: 输入张量 [batch_size, seq_len, dim]
            start_pos: 当前序列在缓存中的起始位置(推理时用)
            freqs_cis: 旋转位置编码 [seq_len, head_dim//2]
            mask: 注意力掩码 [seq_len, cache_len](避免关注未来token)
        Returns:
            注意力输出 [batch_size, seq_len, dim]
        """
        bsz, seqlen, _ = x.shape  # batch_size, 当前序列长度, 模型dim

        # 1. Q/K/V投影 + 形状调整:[batch_size, seq_len, 头数×头维度] -> [batch_size, seq_len, 头数, 头维度]
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # 2. 对Q/K应用旋转位置编码(RoPE)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # 3. 更新K/V缓存:将当前序列的K/V存入缓存(覆盖对应位置)
        self.cache_k = self.cache_k.to(xq.device)  # 确保缓存与输入在同一设备
        self.cache_k[:bsz, start_pos:start_pos+seqlen] = xk
        self.cache_v = self.cache_v.to(xq.device)
        self.cache_v[:bsz, start_pos:start_pos+seqlen] = xv

        # 4. 取出缓存中的所有K/V(包含历史序列 + 当前序列)
        keys = self.cache_k[:bsz, :start_pos+seqlen]  # [bsz, cache_len, n_kv_heads, head_dim]
        values = self.cache_v[:bsz, :start_pos+seqlen]

        # 5. 重复K/V头:若K/V头数 < Q头数,通过重复对齐(分组注意力机制)
        keys = keys.repeat_interleave(self.n_rep, dim=2)  # [bsz, cache_len, n_heads, head_dim]
        values = values.repeat_interleave(self.n_rep, dim=2)

        # 6. 调整维度顺序:为注意力计算做准备
        # [batch_size, 头数, seq_len, 头维度](适配torch.matmul的广播)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)  # [batch_size, 头数, cache_len, 头维度]
        values = values.transpose(1, 2)  # [batch_size, 头数, cache_len, 头维度]

        # 7. 计算注意力分数:Q @ K^T / sqrt(head_dim)(缩放点积注意力)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # 应用掩码:屏蔽未来token(推理时)或无效位置
        if mask is not None:
            scores = scores + mask  # 未来位置设为-inf,softmax后概率为0

        # 8. 计算注意力权重(softmax归一化)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # 9. 注意力加权求和:权重 @ values
        output = torch.matmul(scores, values)  # [batch_size, n_heads, seq_len, head_dim]

        # 10. 维度调整 + 输出投影:合并多头结果,映射回模型dim
        output = output.transpose(1, 2).contiguous()  # [batch_size, seq_len, n_heads, head_dim]
        output = output.view(bsz, seqlen, -1)  # 合并头数和头维度:[batch_size, seq_len, n_heads×head_dim]
        output = self.wo(output)  # 投影到模型dim:[batch_size, seq_len, dim]

        return output

# -------------------------- Transformer块(解码器层) --------------------------
class TransformerBlock(nn.Module):
    """
    单个Transformer解码器块(Llama 2架构)
    结构:预归一化(RMSNorm)-> 多头注意力 -> 残差连接 ->
          预归一化(RMSNorm)-> 前馈网络 -> 残差连接
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id  # 层索引(用于调试/扩展)
        self.args = args

        # 注意力层前的归一化
        self.norm_1 = RMSNorm(args.dim, eps=args.norm_eps)
        # 多头注意力层
        self.attn = Attention(args)
        # 前馈网络前的归一化
        self.norm_2 = RMSNorm(args.dim, eps=args.norm_eps)
        # 前馈网络(SwiGLU)
        self.feed_forward = FeedForward(args)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        前向传播(预归一化+残差连接)
        Args:
            x: 输入 [batch_size, seq_len, dim]
            start_pos: 缓存起始位置
            freqs_cis: 旋转位置编码
            mask: 注意力掩码
        Returns:
            块输出 [batch_size, seq_len, dim]
        """
        # 注意力分支:归一化 -> 注意力 -> 残差连接
        h = x + self.attn(self.norm_1(x), start_pos, freqs_cis, mask)
        # 前馈网络分支:归一化 -> 前馈网络 -> 残差连接
        out = h + self.feed_forward(self.norm_2(h))
        return out

# -------------------------- 完整Transformer解码器模型 --------------------------
class Transformer(nn.Module):
    """
    完整的Transformer解码器模型(Llama 2核心架构)
    结构:词嵌入 -> 多层TransformerBlock -> 最终归一化 -> 输出投影
    """
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size

        # 1. 词嵌入层:将token索引映射为模型dim的向量
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        # 2. 构建多层Transformer块(ModuleList支持索引和迭代)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        # 3. 最终归一化层(所有块之后)
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)

        # 4. 输出投影层:将模型dim映射为词汇表大小(预测下一个token的概率)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # 5. 预计算旋转位置编码(只计算一次,推理时复用)
        # 编码维度 = 每个头的维度(params.dim // params.n_heads)
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len
        )

    @torch.inference_mode()  # 推理模式:禁用梯度计算,节省显存
    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        """
        模型前向传播(推理阶段)
        Args:
            tokens: 输入token序列 [batch_size, seq_len](整数索引)
            start_pos: 当前序列在缓存中的起始位置(0表示全新序列,>0表示增量推理)
        Returns:
            logits: 每个位置的词汇表概率分布 [batch_size, seq_len, vocab_size]
        """
        bsz, seqlen = tokens.shape  # batch_size, 输入序列长度

        # 1. 词嵌入:[batch_size, seq_len] -> [batch_size, seq_len, dim]
        h = self.tok_embeddings(tokens)

        # 2. 加载预计算的位置编码(确保与输入在同一设备)
        self.freqs_cis = self.freqs_cis.to(h.device)
        # 截取当前序列对应的位置编码(start_pos 到 start_pos+seqlen)
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]

        # 3. 生成注意力掩码(仅当序列长度>1时,屏蔽未来token)
        mask = None
        if seqlen > 1:
            # 创建上三角掩码:[seq_len, seq_len],上三角为-inf(屏蔽未来)
            mask = torch.full(
                (seqlen, seqlen), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=1)  # diagonal=1表示不屏蔽当前位置

            # 拼接历史序列的掩码:历史位置允许关注(全0),当前序列用原掩码
            # 最终掩码形状:[seq_len, start_pos + seqlen](cache_len = start_pos + seqlen)
            mask = torch.hstack([
                torch.zeros((seqlen, start_pos), device=tokens.device),
                mask
            ]).type_as(h)

        # 4. 逐层通过Transformer块
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)

        # 5. 最终归一化 + 输出投影(得到logits)
        h = self.norm(h)
        output = self.output(h).float()  # 转为float避免低精度数值问题

        return output

# -------------------------- 模型实例化 --------------------------
if __name__ == "__main__":
    # 初始化模型配置
    model_args = ModelArgs()
    model_args.vocab_size = 32000  # 显式设置词汇表大小(与默认一致,此处为演示)
    
    # 创建模型实例
    model = Transformer(model_args)
    
    # 打印模型结构(查看各层配置和参数数量)
    print("模型结构:")
    print(model)
    
    # 计算模型参数量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n模型总参数量:{total_params / 1e9:.2f}B" if total_params > 1e9 else f"\n模型总参数量:{total_params / 1e6:.2f}M")


输出是这样喵:

transformer(
  (tok_embeddings): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-5): 6 x TransformerBlock(
      (norm_1): RMSNorm()
      (attn): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=4096, bias=False)
        (wv): Linear(in_features=4096, out_features=4096, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (norm_2): RMSNorm()
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=16384, bias=False)
        (w2): Linear(in_features=16384, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=16384, bias=False)
      )
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=4096, out_features=32000, bias=False)
)

模型总参数量:1.87B

我在代码中做了比较详细的注释,希望能读者帮助理解喵

关键设计亮点(也是 Llama 2 的核心优化):

  1. 纯解码器结构(仅用自注意力,无编码器 - 解码器注意力),专注文本生成;
  2. 预归一化(RMSNorm)+ 残差连接,训练更稳定、收敛更快;
  3. 旋转位置编码(RoPE),保持相对位置不变性,支持更长序列;
  4. SwiGLU 前馈网络,比 ReLU 激活更高效,表达能力更强;
  5. 支持 K/V 头数与 Q 头数分离(分组注意力),节省显存;
  6. K/V 缓存,推理时加速增量生成(不用重复计算历史序列)

模型代码文件我会单独放到资源里面的,免费开源,可以自行提取,喜欢我带来的文章的小伙伴别忘了赏点小鱼干喵~

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐