transformer模型详解 手把手教你做个llama模型 Transformer模块实战
·
transformer模型详解 pytorch的nn.Transformer模块实战
PyTorch中的nn.Transformer模块 nn.Transformer是PyTorch内置的Transformer模型实现,基于标准Transformer架构,适用于机器翻译、文本生成等序列任务,提供了编码器、解码器、注意力机制等核心组件。
一、核心模块组件
nn.Transformer包含以下关键组件:
| 组件类 | 功能说明 |
|---|---|
| nn.TransformerEncoder | 编码器模块:将输入序列编码为隐藏表示,由多个TransformerEncoderLayer成 |
| nn.TransformerDecoder | 解码器模块:从隐藏表示生成输出序列,由多个TransformerDecoderLayer组成 |
| nn.TransformerEncoderLayer | 编码器基本单元:包含自注意力机制+前馈网络,支持归一化、Dropout |
| nn.TransformerDecoderLayer | 解码器基本单元:包含自注意力+编码器-解码器注意力+前馈网络 nn.MultiheadAttention 多头注意力机制:将输入分多个头独立计算注意力,再合并结果 |
| nn.Transformer | 完整Transformer模型:由编码器+解码器组成,接收输入/目标序列并输出预测 |
二、__call__函数的作用
PyTorch中nn.Module类实现了__call__方法,因此创建Transformer实例后,可直接通过model(input_data)调用前向传播(无需显式调用forward()):
# model(input_data) 等价于 model.forward(input_data)
# 前者更简洁,是PyTorch中推荐的写法
output = model(src_seq, tgt_seq)
三、最简单的标准Transformer模型
通过nn.Transformer()可快速创建默认配置的Transformer模型(6层编码器+6层解码器)
import torch
import torch.nn as nn# 创建默认配置的Transformer模型
model = nn.Transformer()
# 打印模型结构
print(model)
输出的模型结构示例:
Transformer(
(encoder): TransformerEncoder(
(layers): ModuleList(
(0-5): 6 x TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=2048, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
(linear2): Linear(in_features=2048, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.1, inplace=False)
(dropout2): Dropout(p=0.1, inplace=False)
)
)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(decoder): TransformerDecoder(
(layers): ModuleList(
(0-5): 6 x TransformerDecoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=2048, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
(linear2): Linear(in_features=2048, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.1, inplace=False)
(dropout2): Dropout(p=0.1, inplace=False)
(dropout3): Dropout(p=0.1, inplace=False)
)
)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
)
四、纯解码器模型实现(以Llama风格为例)
很多大模型(如Llama)采用纯解码器架构
nn.TransformerDecoder的代码实战:
import torch
import torch.nn as nn
class TransformerDecoder(nn.Module):
def __init__(
self,
vocab_size: int, # 词汇表大小
emb_size: int, # 词嵌入维度
hidden_size: int, # 隐藏层维度(前馈网络维度)
num_layers: int, # 解码器层数
num_heads: int, # 注意力头数
dropout: float # Dropout比例
):
super().__init__()
# 1. 词嵌入层:将token索引映射为向量
self.embedding = nn.Embedding(vocab_size, emb_size)#传入词汇表大小和词嵌入纬度
# 2. Transformer解码器:由多个解码器层组成
self.decoder = nn.TransformerDecoder(
# 单个解码器层配置
nn.TransformerDecoderLayer(
d_model=emb_size, # 输入/输出维度(需与词嵌入维度一致)
nhead=num_heads, # 注意力头数
dim_feedforward=hidden_size, # 前馈网络隐藏层维度
dropout=dropout, # Dropout比例
batch_first=True # 启用batch_first=True(与forward中输入格式一致)
),
num_layers=num_layers, # 解码器层数
norm=nn.LayerNorm(emb_size) # 最终归一化层(维度需与d_model一致)
)
# 3. 输出层:将解码器输出映射为词汇表概率分布
self.fc = nn.Linear(emb_size, vocab_size) # 输入维度是emb_size(与d_model一致)
def forward(
self,
trg: torch.Tensor, # 目标序列:[batch_size, trg_len](因batch_first=True)
memory: torch.Tensor=None, # 编码器输出(纯解码器时可忽略):[batch_size, src_len, emb_size]
trg_mask: torch.Tensor=None, # 目标序列掩码:[trg_len, trg_len]
memory_mask: torch.Tensor=None # 编码器-解码器注意力掩码:[trg_len, src_len]
):
# 步骤1:词嵌入(无需转置,因batch_first=True已在层中配置)
trg_emb = self.embedding(trg) # [batch_size, trg_len, emb_size]
# 步骤2:解码器前向传播
output = self.decoder(
tgt=trg_emb,
memory=memory,
tgt_mask=trg_mask,
memory_mask=memory_mask
) # 输出:[batch_size, trg_len, emb_size]
# 步骤3:输出层映射为词汇表概率
output = self.fc(output) # [batch_size, trg_len, vocab_size]
return output
# 模型实例化(以Llama 2小配置为例)
model = TransformerDecoder(
vocab_size=32000, # 词汇表大小
emb_size=512, # 词嵌入维度(与d_model一致)
hidden_size=1024, # 前馈网络隐藏层维度
num_layers=6, # 解码器层数
num_heads=4, # 注意力头数(需能整除emb_size:512 ÷ 4 = 128,合法)
dropout=0.1 # Dropout比例
)
print(model)
输出出来是这样喵
TransformerDecoder(
(embedding): Embedding(32000, 512)
(decoder): TransformerDecoder(
(layers): ModuleList(
(0-5): 6 x TransformerDecoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=1024, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
(linear2): Linear(in_features=1024, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.1, inplace=False)
(dropout2): Dropout(p=0.1, inplace=False)
(dropout3): Dropout(p=0.1, inplace=False)
)
)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(fc): Linear(in_features=512, out_features=32000, bias=True)
这就是解码器的实现啦喵,我做了比较详细的注释,希望你能看得懂喵
五、Llama 2风格模型简化实现(含RMSNorm、旋转位置编码)
芝士适配Llama 2的简化模型(替换并行层为标准层,适配CPU运行)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple
# -------------------------- 模型配置类 --------------------------
@dataclass
class ModelArgs:
"""
模型超参数配置类(基于Llama 2架构简化)
用dataclass自动生成构造函数,方便参数管理
"""
dim: int = 4096 # 模型核心维度(隐藏层/词嵌入维度)
n_layers: int = 6 # Transformer块层数(原Llama 2为32层,此处简化为6层)
n_heads: int = 32 # 注意力头数(查询Q的头数)
n_kv_heads: Optional[int] = None # K/V注意力头数(默认与n_heads相同,可设更小以节省显存)
vocab_size: int = 32000 # 词汇表大小(Llama 2默认32000)
multiple_of: int = 256 # SwiGLU前馈网络隐藏层维度对齐因子(确保维度是该值的倍数)
ffn_dim_multiplier: Optional[float] = None # 前馈网络维度乘数(覆盖默认计算逻辑)
norm_eps: float = 1e-5 # RMSNorm归一化层的epsilon(避免除零)
max_batch_size: int = 32 # 最大批次大小(用于初始化K/V缓存)
max_seq_len: int = 2048 # 最大序列长度(用于初始化K/V缓存和位置编码)
# -------------------------- RMSNorm归一化层 --------------------------
class RMSNorm(nn.Module):
"""
Root Mean Square Normalization(均方根归一化)
相比LayerNorm,RMSNorm只归一化,不中心化(移除均值),计算更快、更稳定
公式:output = (x / sqrt(E[x²] + eps)) * weight
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps # 数值稳定性参数
self.weight = nn.Parameter(torch.ones(dim)) # 可学习的缩放参数(保持表达能力)
def norm(self, x: torch.Tensor) -> torch.Tensor:
"""核心归一化计算:逐样本沿最后一维计算RMS"""
# x.pow(2): 元素平方 -> mean(-1, keepdim=True): 最后一维求均值并保持维度
# torch.rsqrt: 求倒数平方根 -> 等价于 1/sqrt(...)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""前向传播:归一化 + 可学习缩放"""
# 转为float计算(避免低精度下的数值问题),再转回原数据类型
output = self.norm(x.float()).type_as(x)
return output * self.weight # 缩放权重
# -------------------------- 旋转位置编码(RoPE) --------------------------
def precompute_freqs_cis(
dim: int, end: int, theta: float = 10000.0
) -> torch.Tensor:
"""
预计算旋转位置编码的复数形式(Complex Exponential)
用于RoPE:将位置信息编码到Q/K的相位中,保持相对位置不变性
Args:
dim: 每个注意力头的维度(必须是偶数)
end: 最大序列长度(预计算到该长度的编码)
theta: 频率基数(默认10000,遵循原始Transformer位置编码)
Returns:
freqs_cis: [end, dim//2] 的复数Tensor(模长=1,相位随位置和维度变化)
"""
# 计算频率:theta^(2i/dim) 的倒数,i是维度索引(0,1,...,dim//2-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # 序列位置 [0,1,...,end-1]
freqs = torch.outer(t, freqs).float() # 位置×频率:[end, dim//2]
# 转为复数形式:polar(模长, 相位) -> 模长=1,相位=freqs
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
调整freqs_cis的形状,使其能与Q/K进行广播乘法
Args:
freqs_cis: 预计算的位置编码 [seq_len, dim//2]
x: Q或K [batch_size, seq_len, n_heads, head_dim]
Returns:
重塑后的freqs_cis,适配x的维度
"""
ndim = x.ndim
assert 0 <= 1 < ndim, "x的维度必须至少为2(batch_size, seq_len, ...)"
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), "freqs_cis形状不匹配"
# 构建广播形状:除了seq_len(第1维)和head_dim(最后1维),其余维度设为1
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
对查询Q和键K应用旋转位置编码
原理:将Q/K的实数值转为复数,与位置编码复数相乘(旋转相位),再转回实数
Args:
xq: 查询向量 [batch_size, seq_len, n_heads, head_dim]
xk: 键向量 [batch_size, seq_len, n_kv_heads, head_dim]
freqs_cis: 预计算的位置编码 [seq_len, head_dim//2]
Returns:
应用RoPE后的Q和K
"""
# 将Q/K重塑为复数形式:[..., head_dim] -> [..., head_dim//2, 2](实部+虚部)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# 调整freqs_cis形状以适配广播
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
# 复数乘法:旋转相位(模长不变,相位相加)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # 转回实数并展平最后两维
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
# 转回原数据类型
return xq_out.type_as(xq), xk_out.type_as(xk)
# -------------------------- 前馈网络(SwiGLU) --------------------------
class FeedForward(nn.Module):
"""
基于SwiGLU激活函数的前馈网络(Llama 2使用的FFN结构)
结构:Linear -> SwiGLU -> Dropout -> Linear
其中中间层维度由multiple_of和ffn_dim_multiplier共同决定
"""
def __init__(self, args: ModelArgs):
super().__init__()
# 计算中间层维度:默认是dim * 4,再按multiple_of对齐
hidden_dim = 4 * args.dim
if args.ffn_dim_multiplier is not None:
hidden_dim = int(args.ffn_dim_multiplier * args.dim)
# 确保中间层维度是multiple_of的整数倍(硬件优化)
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
self.w1 = nn.Linear(args.dim, hidden_dim, bias=False) # 第一层投影
self.w2 = nn.Linear(hidden_dim, args.dim, bias=False) # 输出投影
self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) # 门控投影(SwiGLU特有)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""前向传播:x -> w1(x) * sigmoid(w3(x)) -> w2(...)"""
# SwiGLU激活:(w1(x) * sigmoid(w3(x))),相比ReLU更平滑、表达能力更强
return self.w2(F.silu(self.w1(x)) * self.w3(x))
# -------------------------- 多头注意力层 --------------------------
class Attention(nn.Module):
"""
多头注意力层(支持K/V头数少于Q头数,即分组注意力)
包含K/V缓存机制,用于推理时加速(避免重复计算历史序列的K/V)
"""
def __init__(self, args: ModelArgs):
super().__init__()
# K/V头数:默认与Q头数相同,可单独设置以减少计算量
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_local_heads = args.n_heads # Q的注意力头数
self.n_local_kv_heads = self.n_kv_heads # K/V的注意力头数
self.n_rep = self.n_local_heads // self.n_local_kv_heads # K/V头重复次数(对齐Q头数)
self.head_dim = args.dim // args.n_heads # 每个注意力头的维度
# Q/K/V投影层:将输入dim映射到(头数×头维度)
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
# 输出投影层:将多头注意力结果映射回模型dim
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
# K/V缓存:存储历史序列的K和V,形状[max_batch_size, max_seq_len, n_kv_heads, head_dim]
# 用于推理时增量计算(只计算当前token的Q与历史K/V的注意力)
self.cache_k = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)
)
self.cache_v = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)
)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
前向传播:计算多头注意力
Args:
x: 输入张量 [batch_size, seq_len, dim]
start_pos: 当前序列在缓存中的起始位置(推理时用)
freqs_cis: 旋转位置编码 [seq_len, head_dim//2]
mask: 注意力掩码 [seq_len, cache_len](避免关注未来token)
Returns:
注意力输出 [batch_size, seq_len, dim]
"""
bsz, seqlen, _ = x.shape # batch_size, 当前序列长度, 模型dim
# 1. Q/K/V投影 + 形状调整:[batch_size, seq_len, 头数×头维度] -> [batch_size, seq_len, 头数, 头维度]
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# 2. 对Q/K应用旋转位置编码(RoPE)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# 3. 更新K/V缓存:将当前序列的K/V存入缓存(覆盖对应位置)
self.cache_k = self.cache_k.to(xq.device) # 确保缓存与输入在同一设备
self.cache_k[:bsz, start_pos:start_pos+seqlen] = xk
self.cache_v = self.cache_v.to(xq.device)
self.cache_v[:bsz, start_pos:start_pos+seqlen] = xv
# 4. 取出缓存中的所有K/V(包含历史序列 + 当前序列)
keys = self.cache_k[:bsz, :start_pos+seqlen] # [bsz, cache_len, n_kv_heads, head_dim]
values = self.cache_v[:bsz, :start_pos+seqlen]
# 5. 重复K/V头:若K/V头数 < Q头数,通过重复对齐(分组注意力机制)
keys = keys.repeat_interleave(self.n_rep, dim=2) # [bsz, cache_len, n_heads, head_dim]
values = values.repeat_interleave(self.n_rep, dim=2)
# 6. 调整维度顺序:为注意力计算做准备
# [batch_size, 头数, seq_len, 头维度](适配torch.matmul的广播)
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2) # [batch_size, 头数, cache_len, 头维度]
values = values.transpose(1, 2) # [batch_size, 头数, cache_len, 头维度]
# 7. 计算注意力分数:Q @ K^T / sqrt(head_dim)(缩放点积注意力)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
# 应用掩码:屏蔽未来token(推理时)或无效位置
if mask is not None:
scores = scores + mask # 未来位置设为-inf,softmax后概率为0
# 8. 计算注意力权重(softmax归一化)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# 9. 注意力加权求和:权重 @ values
output = torch.matmul(scores, values) # [batch_size, n_heads, seq_len, head_dim]
# 10. 维度调整 + 输出投影:合并多头结果,映射回模型dim
output = output.transpose(1, 2).contiguous() # [batch_size, seq_len, n_heads, head_dim]
output = output.view(bsz, seqlen, -1) # 合并头数和头维度:[batch_size, seq_len, n_heads×head_dim]
output = self.wo(output) # 投影到模型dim:[batch_size, seq_len, dim]
return output
# -------------------------- Transformer块(解码器层) --------------------------
class TransformerBlock(nn.Module):
"""
单个Transformer解码器块(Llama 2架构)
结构:预归一化(RMSNorm)-> 多头注意力 -> 残差连接 ->
预归一化(RMSNorm)-> 前馈网络 -> 残差连接
"""
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.layer_id = layer_id # 层索引(用于调试/扩展)
self.args = args
# 注意力层前的归一化
self.norm_1 = RMSNorm(args.dim, eps=args.norm_eps)
# 多头注意力层
self.attn = Attention(args)
# 前馈网络前的归一化
self.norm_2 = RMSNorm(args.dim, eps=args.norm_eps)
# 前馈网络(SwiGLU)
self.feed_forward = FeedForward(args)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
前向传播(预归一化+残差连接)
Args:
x: 输入 [batch_size, seq_len, dim]
start_pos: 缓存起始位置
freqs_cis: 旋转位置编码
mask: 注意力掩码
Returns:
块输出 [batch_size, seq_len, dim]
"""
# 注意力分支:归一化 -> 注意力 -> 残差连接
h = x + self.attn(self.norm_1(x), start_pos, freqs_cis, mask)
# 前馈网络分支:归一化 -> 前馈网络 -> 残差连接
out = h + self.feed_forward(self.norm_2(h))
return out
# -------------------------- 完整Transformer解码器模型 --------------------------
class Transformer(nn.Module):
"""
完整的Transformer解码器模型(Llama 2核心架构)
结构:词嵌入 -> 多层TransformerBlock -> 最终归一化 -> 输出投影
"""
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
# 1. 词嵌入层:将token索引映射为模型dim的向量
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
# 2. 构建多层Transformer块(ModuleList支持索引和迭代)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
# 3. 最终归一化层(所有块之后)
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
# 4. 输出投影层:将模型dim映射为词汇表大小(预测下一个token的概率)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
# 5. 预计算旋转位置编码(只计算一次,推理时复用)
# 编码维度 = 每个头的维度(params.dim // params.n_heads)
self.freqs_cis = precompute_freqs_cis(
self.params.dim // self.params.n_heads, self.params.max_seq_len
)
@torch.inference_mode() # 推理模式:禁用梯度计算,节省显存
def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
"""
模型前向传播(推理阶段)
Args:
tokens: 输入token序列 [batch_size, seq_len](整数索引)
start_pos: 当前序列在缓存中的起始位置(0表示全新序列,>0表示增量推理)
Returns:
logits: 每个位置的词汇表概率分布 [batch_size, seq_len, vocab_size]
"""
bsz, seqlen = tokens.shape # batch_size, 输入序列长度
# 1. 词嵌入:[batch_size, seq_len] -> [batch_size, seq_len, dim]
h = self.tok_embeddings(tokens)
# 2. 加载预计算的位置编码(确保与输入在同一设备)
self.freqs_cis = self.freqs_cis.to(h.device)
# 截取当前序列对应的位置编码(start_pos 到 start_pos+seqlen)
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
# 3. 生成注意力掩码(仅当序列长度>1时,屏蔽未来token)
mask = None
if seqlen > 1:
# 创建上三角掩码:[seq_len, seq_len],上三角为-inf(屏蔽未来)
mask = torch.full(
(seqlen, seqlen), float("-inf"), device=tokens.device
)
mask = torch.triu(mask, diagonal=1) # diagonal=1表示不屏蔽当前位置
# 拼接历史序列的掩码:历史位置允许关注(全0),当前序列用原掩码
# 最终掩码形状:[seq_len, start_pos + seqlen](cache_len = start_pos + seqlen)
mask = torch.hstack([
torch.zeros((seqlen, start_pos), device=tokens.device),
mask
]).type_as(h)
# 4. 逐层通过Transformer块
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
# 5. 最终归一化 + 输出投影(得到logits)
h = self.norm(h)
output = self.output(h).float() # 转为float避免低精度数值问题
return output
# -------------------------- 模型实例化 --------------------------
if __name__ == "__main__":
# 初始化模型配置
model_args = ModelArgs()
model_args.vocab_size = 32000 # 显式设置词汇表大小(与默认一致,此处为演示)
# 创建模型实例
model = Transformer(model_args)
# 打印模型结构(查看各层配置和参数数量)
print("模型结构:")
print(model)
# 计算模型参数量
total_params = sum(p.numel() for p in model.parameters())
print(f"\n模型总参数量:{total_params / 1e9:.2f}B" if total_params > 1e9 else f"\n模型总参数量:{total_params / 1e6:.2f}M")
输出是这样喵:
transformer(
(tok_embeddings): Embedding(32000, 4096)
(layers): ModuleList(
(0-5): 6 x TransformerBlock(
(norm_1): RMSNorm()
(attn): Attention(
(wq): Linear(in_features=4096, out_features=4096, bias=False)
(wk): Linear(in_features=4096, out_features=4096, bias=False)
(wv): Linear(in_features=4096, out_features=4096, bias=False)
(wo): Linear(in_features=4096, out_features=4096, bias=False)
)
(norm_2): RMSNorm()
(feed_forward): FeedForward(
(w1): Linear(in_features=4096, out_features=16384, bias=False)
(w2): Linear(in_features=16384, out_features=4096, bias=False)
(w3): Linear(in_features=4096, out_features=16384, bias=False)
)
)
)
(norm): RMSNorm()
(output): Linear(in_features=4096, out_features=32000, bias=False)
)
模型总参数量:1.87B
我在代码中做了比较详细的注释,希望能读者帮助理解喵
关键设计亮点(也是 Llama 2 的核心优化):
- 纯解码器结构(仅用自注意力,无编码器 - 解码器注意力),专注文本生成;
- 预归一化(RMSNorm)+ 残差连接,训练更稳定、收敛更快;
- 旋转位置编码(RoPE),保持相对位置不变性,支持更长序列;
- SwiGLU 前馈网络,比 ReLU 激活更高效,表达能力更强;
- 支持 K/V 头数与 Q 头数分离(分组注意力),节省显存;
- K/V 缓存,推理时加速增量生成(不用重复计算历史序列)
模型代码文件我会单独放到资源里面的,免费开源,可以自行提取,喜欢我带来的文章的小伙伴别忘了赏点小鱼干喵~
更多推荐




所有评论(0)