# Transformer 从0到1:认知科学中的注意力——从直觉到算法

## 第一章:注意力的认知科学起源

### 1.1 人类注意力:认知资源的分配机制

在人类认知系统中,注意力是一种高度选择性的信息处理机制。当我们面对外部世界的海量信息时,大脑并没有能力对所有输入进行同等深度的处理。相反,认知系统会智能地分配有限的计算资源,聚焦于当前任务最相关的信息,而忽略或抑制无关的干扰。

这种选择性注意的经典实验可以追溯到20世纪50年代。英国心理学家Donald Broadbent提出的过滤器模型(Filter Model)认为,人类的感知系统就像一个狭窄的瓶颈,大量的感觉信息进入后,只有一个通道能够通过选择性过滤器,进入高级认知处理阶段。后来的研究进一步揭示了注意力的两种基本模式:

**内源性注意(Endogenous Attention)**:这是我们主动控制的、自上而下的注意力。当你正在阅读一篇技术文章时,决定忽略手机通知的声音,这就是内源性注意在工作。它是有目的的、受意志控制的,通常需要一定的认知努力。

**外源性注意(Exogenous Attention)**:这是一种被动的、刺激驱动的注意机制。当一个突然的响声吸引你的注意,或者页面上的闪烁元素抓住了你的视线,这就是外源性注意在起作用。它是自动的、快速的,但持续时间较短。

神经科学研究发现,注意力在神经层面体现为特定脑区(如前额叶皮层、顶叶皮层)对感觉皮层的调制作用。当我们将注意力投向某个视觉区域时,相应区域的神经元放电率会显著提高,信息处理变得更加精确和高效。这种“神经增益控制”机制,本质上是一种动态的资源分配策略。

### 1.2 从生物学到计算:注意力机制的算法化

认知科学对注意力的研究,为人工智能领域提供了重要的灵感来源。在机器学习领域,注意力机制的核心思想是:让模型在处理信息时,能够动态地为不同部分分配不同的权重。

这个思想最早在计算机视觉领域有所探索,但真正引发革命的是2014年Bahdanau等人将注意力机制引入神经机器翻译。在此之前,序列到序列(Seq2Seq)模型面临一个根本性的问题:编码器需要将整个输入序列压缩成一个固定长度的向量(称为“上下文向量”),然后解码器基于这个向量生成输出。当输入序列较长时,这个瓶颈结构会导致严重的信息丢失——早期的信息会被后续的信息稀释,模型难以有效记忆长距离依赖关系。

注意力机制的引入彻底解决了这个问题。在带注意力的翻译模型中,解码器在生成每个目标词时,不再依赖于一个单一的上下文向量,而是可以动态地“回看”编码器处理的所有源语言词,并为每个源语言词计算一个相关性分数(即注意力权重)。这些权重决定了在生成当前词时,每个源语言词应该贡献多少信息。

这种计算模式与人类注意力有着惊人的相似性:
- **选择性聚焦**:模型将大部分权重分配给少数几个最相关的源语言词,类似于人类的注意力聚焦
- **上下文整合**:模型仍然保留了其他词的少量信息,类似于人类注意力的外围视野
- **动态调整**:每个解码步骤的注意力分布都不同,体现了注意力的动态性

从认知科学的视角来看,注意力机制可以被理解为一种“软寻址”机制:编码器的输出构成了一个“记忆存储器”,解码器的当前状态作为“查询”,通过计算查询与每个记忆项的匹配度,得到注意力分布,最后加权求和得到“读取”的内容。

### 1.3 注意力的三种基本形式

在Transformer出现之前,注意力机制已经发展出几种不同的形式。理解这些基本形式是理解Transformer自注意力的基础。

**加性注意力(Additive Attention)**

Bahdanau等人提出的原始注意力机制,也称为加性注意力。其核心思想是:给定查询向量q和一组键向量k₁, k₂, …, kₙ,通过一个前馈神经网络来计算注意力分数:

score(q, k_i) = v^T · tanh(W_q q + W_k k_i)

其中v、W_q、W_k是可学习的参数。加性注意力的优势在于它能够处理查询和键具有不同维度的情况,但计算效率相对较低,因为每个注意力分数的计算都需要经过一次前馈网络。

**点积注意力(Dot-Product Attention)**

点积注意力是一种更简洁的形式:注意力分数直接通过查询和键的点积来计算:

score(q, k_i) = q · k_i

这种形式的计算效率更高,因为点积可以并行计算。但它要求查询和键具有相同的维度。在实际应用中,为了防止点积结果过大导致softmax进入梯度饱和区域,通常会对点积进行缩放:

score(q, k_i) = (q · k_i) / √d

其中d是向量的维度。这种缩放点积注意力正是Transformer采用的形式。

**多头注意力(Multi-Head Attention)**

多头注意力是Transformer的关键创新之一。其核心思想是:与其使用单一的注意力函数,不如将查询、键、值分别投影到多个不同的表示子空间,在每个子空间上独立执行注意力计算,然后将所有头的输出拼接起来,再通过一个线性变换。

这种设计的动机可以从认知科学的角度理解:人类注意力的功能是高度分化的。视觉注意系统包含多个并行的处理通路——负责运动检测的“where通路”和负责物体识别的“what通路”;听觉注意也包含对不同频率、不同空间位置的并行处理。多头注意力正是模拟了这种并行处理机制:每个注意力头可以学习关注不同类型的模式或关系,从而捕捉更丰富的特征交互。

## 第二章:Transformer架构的深度解析

### 2.1 从RNN/LSTM的困境到Transformer的革命

在Transformer出现之前,循环神经网络(RNN)及其变体LSTM、GRU是处理序列数据的主流架构。这些模型通过循环连接逐步处理序列中的每个元素,在时间步t的状态依赖于前一时间步t-1的状态,从而形成对序列的“记忆”。

然而,RNN类模型存在几个根本性的问题:

**顺序计算瓶颈**:由于每个时间步的计算依赖于前一个时间步的结果,RNN无法并行化。这导致训练时间随序列长度线性增长,在处理长序列时效率极低。

**梯度消失与爆炸**:在反向传播过程中,梯度需要沿着时间维度传播。当序列较长时,连乘效应会导致梯度指数级衰减(消失)或增长(爆炸)。虽然LSTM等结构通过门控机制缓解了这个问题,但长距离依赖的学习仍然困难。

**固定长度的记忆状态**:RNN的隐藏状态是一个固定维度的向量,随着序列的推进,早期输入的信息会被不断覆盖和稀释。理论上,LSTM可以选择性地记忆长期信息,但在实践中,处理超过100个时间步的依赖关系仍然非常困难。

Transformer在2017年由Vaswani等人提出,通过彻底摒弃循环结构,开创了基于纯注意力的序列建模范式。其核心创新包括:

1. **自注意力机制**:序列中的每个位置都可以直接与所有位置交互,计算它们之间的相关性。这种全连接的交互模式使得任意两个位置之间的路径长度为O(1),彻底解决了长距离依赖问题。

2. **并行化计算**:由于没有时序依赖,Transformer可以在一次前向传播中处理整个序列,实现了真正的并行化,大幅提升了训练效率。

3. **位置编码**:由于自注意力本身是置换不变的(即不考虑词序),Transformer通过添加位置编码来注入序列顺序信息。

### 2.2 编码器-解码器架构概览

Transformer遵循经典的编码器-解码器架构,但与传统RNN-based Seq2Seq不同的是,编码器和解码器都由多个相同的层堆叠而成。

**编码器(Encoder)**

编码器由N个相同的层组成(原论文中N=6)。每个层包含两个子层:
- 多头自注意力子层
- 前馈神经网络子层

每个子层都使用了残差连接和层归一化:子层输出 = LayerNorm(x + Sublayer(x))

**解码器(Decoder)**

解码器同样由N个相同的层组成,但每个层包含三个子层:
- 掩码多头自注意力子层(Masked Multi-Head Attention)
- 编码器-解码器注意力子层(Cross-Attention)
- 前馈神经网络子层

掩码自注意力的作用是防止解码器在生成当前位置时“看到”未来的位置,确保自回归生成时的因果性。

### 2.3 自注意力:核心机制的数学原理

自注意力是Transformer最核心的组件。对于输入序列X = [x₁, x₂, …, xₙ],其中每个x_i是d_model维的向量,自注意力的计算过程如下:

**第一步:生成查询、键、值**

通过三个不同的线性变换矩阵W_Q、W_K、W_V,将每个输入向量投影到三个不同的空间:

Q = X W_Q
K = X W_K
V = X W_V

其中Q、K、V的形状都是(n, d_k),d_k通常取d_model / h,h是注意力头的数量。

**第二步:计算注意力分数**

通过查询矩阵和键矩阵的点积,计算所有位置之间的相关性分数:

Scores = Q K^T / √d_k

这里的缩放因子√d_k是为了防止点积结果过大。当d_k较大时,点积结果的方差也会变大,使得softmax函数的梯度进入饱和区域,导致训练困难。

**第三步:应用softmax归一化**

Attention_Weights = softmax(Scores)

Softmax函数将每一行的分数转换为概率分布,确保所有位置的权重之和为1。注意力的每一行表示了该位置对所有位置的关注程度。

**第四步:加权求和**

Output = Attention_Weights · V

最终的输出是值向量的加权和。每个位置的输出都是对所有位置的值向量进行加权组合,权重由该位置与各位置的相似度决定。

从认知科学的角度,这个过程可以这样理解:对于每个“查询”位置,系统会评估它与所有“键”位置的匹配程度,然后根据匹配度从对应的“值”中提取信息。这类似于人类在阅读时,当前关注的词会激活记忆中与之相关的其他词的语义表征,并将它们整合到当前的语义理解中。

### 2.4 多头注意力:并行特征提取

多头注意力将上述过程重复h次,每次使用不同的线性投影矩阵:

MultiHead(Q, K, V) = Concat(head₁, …, head_h) W_O
其中 head_i = Attention(Q W_{Q_i}, K W_{K_i}, V W_{V_i})

每个注意力头可以学习关注不同的关系模式。例如:
- 某些头可能关注相邻词的局部依赖关系
- 某些头可能关注跨越长距离的全局依赖关系
- 某些头可能关注语法结构的依存关系
- 某些头可能关注语义相似性

这种并行机制使得Transformer能够同时捕捉多种类型的特征交互,类似于人类认知系统中的多通道并行处理。

### 2.5 位置编码:注入序列顺序信息

由于自注意力本身是置换不变的——将输入序列的顺序打乱后,每个位置仍然会与所有位置交互,但输出的顺序也会相应打乱。这意味着Transformer本身不具备感知词序的能力。

为了注入位置信息,Transformer使用位置编码(Positional Encoding)将位置信息添加到输入嵌入中。原始论文使用正弦和余弦函数:

PE_{(pos, 2i)} = sin(pos / 10000^{2i/d_model})
PE_{(pos, 2i+1)} = cos(pos / 10000^{2i/d_model})

其中pos是位置索引,i是维度索引。

选择这种函数形式的原因包括:
- 不同位置的编码向量具有可区分的模式
- 任意两个位置的编码之间的相对位置关系可以通过线性变换表示
- 正弦/余弦函数的周期性使得模型可以外推到训练时未见过的序列长度

### 2.6 前馈网络与归一化层

**前馈网络(Feed-Forward Network)**

每个编码器和解码器层都包含一个位置前馈网络(Position-wise FFN),它对每个位置独立地应用相同的全连接网络:

FFN(x) = max(0, xW₁ + b₁) W₂ + b₂

这是一个两层的全连接网络,中间使用ReLU激活函数。第一层将维度从d_model扩展到d_ff(原论文中d_ff=2048),第二层再将维度压缩回d_model。

“位置前馈”意味着对序列中的每个位置使用相同的网络参数,但不同位置的处理是独立的。这类似于对每个位置分别应用相同的非线性变换。

**残差连接与层归一化**

残差连接和层归一化对于训练深度Transformer至关重要:

- **残差连接**:将子层的输入直接加到输出上,形成“捷径”。这有助于梯度流动,缓解深层网络中的梯度消失问题。
- **层归一化**:对每个样本的特征维度进行归一化,使训练更加稳定。与批归一化不同,层归一化不依赖于批次大小,更适合处理变长序列。

### 2.7 编码器-解码器注意力

解码器中的交叉注意力子层(Cross-Attention)连接编码器和解码器。在这个子层中:
- 查询Q来自解码器的上一层的输出
- 键K和值V来自编码器的输出

这种设计允许解码器在生成每个输出词时,能够关注输入序列的所有位置。这类似于人类在翻译句子时,生成每个目标词时都会回顾源语言句子的相关部分。

### 2.8 掩码机制

解码器中的掩码自注意力使用了一个特殊的掩码,称为“后续掩码”(Look-ahead Mask)。其形式是一个上三角矩阵,其中未来位置的值被设置为-∞,使得softmax后这些位置的注意力权重为0:

Masked_Scores = Scores + M, 其中 M_{ij} = 0 if i ≥ j else -∞

这确保了在预测第i个位置时,模型只能依赖于位置0到i-1的信息,而不能“看到”未来的词。这是自回归生成的基本要求。

## 第三章:从零实现的完整代码

下面,我们将用PyTorch从零实现一个完整的Transformer模型。代码将遵循原始论文的架构,并包含详细的注释。

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy

class LayerNorm(nn.Module):
    """层归一化模块"""
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))  # 缩放参数
        self.b_2 = nn.Parameter(torch.zeros(features))  # 偏移参数
        self.eps = eps
        
    def forward(self, x):
        # x形状: [batch_size, seq_len, features]
        mean = x.mean(-1, keepdim=True)  # 沿特征维度计算均值
        std = x.std(-1, keepdim=True)    # 沿特征维度计算标准差
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class PositionalEncoding(nn.Module):
    """位置编码模块"""
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)   # 偶数维度用正弦
        pe[:, 1::2] = torch.cos(position * div_term)   # 奇数维度用余弦
        pe = pe.unsqueeze(0)  # 添加batch维度,形状: [1, max_len, d_model]
        
        self.register_buffer('pe', pe)  # 注册为buffer,不参与梯度更新
        
    def forward(self, x):
        # x形状: [batch_size, seq_len, d_model]
        x = x + self.pe[:, :x.size(1), :]  # 添加位置编码
        return self.dropout(x)


def clones(module, N):
    """创建N个相同的层"""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def attention(query, key, value, mask=None, dropout=None):
    """缩放点积注意力核心函数
    
    参数:
        query: [batch_size, h, seq_len_q, d_k]
        key: [batch_size, h, seq_len_k, d_k]
        value: [batch_size, h, seq_len_v, d_v] (seq_len_v = seq_len_k)
        mask: [batch_size, seq_len_q, seq_len_k] 或广播兼容的形状
        dropout: Dropout层
        
    返回:
        output: [batch_size, h, seq_len_q, d_v]
        attention_weights: [batch_size, h, seq_len_q, seq_len_k]
    """
    d_k = query.size(-1)
    
    # 计算注意力分数: Q * K^T / sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 应用掩码(如果提供)
    if mask is not None:
        # 将掩码为0的位置设置为一个很小的负数,使其在softmax后权重接近0
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Softmax归一化
    p_attn = F.softmax(scores, dim=-1)
    
    # 应用dropout
    if dropout is not None:
        p_attn = dropout(p_attn)
    
    # 加权求和: attention_weights * V
    return torch.matmul(p_attn, value), p_attn


class MultiHeadAttention(nn.Module):
    """多头注意力模块"""
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0, "d_model必须能被h整除"
        
        self.h = h                      # 注意力头数
        self.d_k = d_model // h         # 每个头的维度
        self.d_model = d_model
        
        # 线性投影层
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None                # 存储注意力权重用于可视化
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        """
        参数:
            query: [batch_size, seq_len_q, d_model]
            key: [batch_size, seq_len_k, d_model]
            value: [batch_size, seq_len_v, d_model]
            mask: [batch_size, seq_len_q, seq_len_k] 或广播兼容的形状
        """
        if mask is not None:
            # 为多头扩展mask维度: [batch_size, 1, seq_len_q, seq_len_k]
            mask = mask.unsqueeze(1)
        
        batch_size = query.size(0)
        
        # 1. 线性投影并分割成多头
        # 将query, key, value通过线性层,然后reshape为[batch_size, h, seq_len, d_k]
        query, key, value = [
            lin(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]
        
        # 2. 应用注意力机制
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        
        # 3. 合并多头: [batch_size, h, seq_len_q, d_k] -> [batch_size, seq_len_q, d_model]
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # 4. 最终的线性投影
        return self.linears[-1](x)


class PositionwiseFeedForward(nn.Module):
    """位置前馈网络"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # 第一层线性变换
        self.w_2 = nn.Linear(d_ff, d_model)   # 第二层线性变换
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x形状: [batch_size, seq_len, d_model]
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class SublayerConnection(nn.Module):
    """子层连接模块:残差连接 + 层归一化"""
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, sublayer):
        """应用残差连接:x + dropout(sublayer(norm(x)))"""
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    """编码器层"""
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size
        
    def forward(self, x, mask):
        """编码器层前向传播"""
        # 子层1: 多头自注意力
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        # 子层2: 前馈网络
        return self.sublayer[1](x, self.feed_forward)


class Encoder(nn.Module):
    """编码器:由N个编码器层堆叠而成"""
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, mask):
        """通过所有编码器层"""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class DecoderLayer(nn.Module):
    """解码器层"""
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn          # 掩码自注意力
        self.src_attn = src_attn            # 交叉注意力
        self.feed_forward = feed_forward    # 前馈网络
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
        
    def forward(self, x, memory, src_mask, tgt_mask):
        """解码器层前向传播
        
        参数:
            x: 解码器输入 [batch_size, seq_len_tgt, d_model]
            memory: 编码器输出 [batch_size, seq_len_src, d_model]
            src_mask: 源序列的填充掩码
            tgt_mask: 目标序列的后续掩码
        """
        # 子层1: 掩码自注意力(不能看到未来位置)
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # 子层2: 编码器-解码器注意力(查询来自解码器,键值来自编码器)
        x = self.sublayer[1](x, lambda x: self.src_attn(x, memory, memory, src_mask))
        # 子层3: 前馈网络
        return self.sublayer[2](x, self.feed_forward)


class Decoder(nn.Module):
    """解码器:由N个解码器层堆叠而成"""
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, memory, src_mask, tgt_mask):
        """通过所有解码器层"""
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)


class Transformer(nn.Module):
    """完整的Transformer模型"""
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator, d_model):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed      # 源序列嵌入层
        self.tgt_embed = tgt_embed      # 目标序列嵌入层
        self.generator = generator      # 输出生成层
        self.d_model = d_model
        
    def encode(self, src, src_mask):
        """编码器前向传播"""
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        """解码器前向传播"""
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
    
    def forward(self, src, tgt, src_mask, tgt_mask):
        """完整前向传播"""
        memory = self.encode(src, src_mask)
        output = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.generator(output)


class Generator(nn.Module):
    """输出生成层:线性变换 + Softmax"""
    def __init__(self, d_model, vocab_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab_size)
        
    def forward(self, x):
        # x形状: [batch_size, seq_len, d_model]
        # 输出形状: [batch_size, seq_len, vocab_size]
        return F.log_softmax(self.proj(x), dim=-1)


def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """构建Transformer模型的工厂函数"""
    c = copy.deepcopy
    attn = MultiHeadAttention(h, d_model, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    
    model = Transformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab),
        d_model
    )
    
    # 参数初始化
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    
    return model


class Embeddings(nn.Module):
    """词嵌入层,同时包含缩放因子√d_model"""
    def __init__(self, d_model, vocab_size):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
        
    def forward(self, x):
        # 按照论文,对嵌入乘以√d_model
        return self.lut(x) * math.sqrt(self.d_model)
```

## 第四章:训练与推理实践

### 4.1 数据准备与批处理

在训练Transformer时,需要准备适当的批处理策略。由于序列长度不同,通常需要对同一批次内的序列进行填充(padding),使其长度一致。同时,需要创建相应的掩码来忽略填充位置。

```python
class Batch:
    """用于训练的数据批次"""
    def __init__(self, src, tgt=None, pad=0):
        self.src = src  # 源序列 [batch_size, src_len]
        self.src_mask = (src != pad).unsqueeze(-2)  # [batch_size, 1, src_len]
        
        if tgt is not None:
            # 创建目标序列的输入(去掉最后一个token)和输出(去掉第一个token)
            # 用于自回归训练
            self.tgt = tgt[:, :-1]  # 解码器输入
            self.tgt_y = tgt[:, 1:]  # 解码器期望输出
            self.tgt_mask = self.make_std_mask(self.tgt, pad)
            self.ntokens = (self.tgt_y != pad).data.sum()
    
    @staticmethod
    def make_std_mask(tgt, pad):
        """创建目标序列掩码:填充掩码 + 后续掩码"""
        # 填充掩码
        tgt_mask = (tgt != pad).unsqueeze(-2)
        # 后续掩码(上三角矩阵)
        subsequent_mask = torch.triu(
            torch.ones((tgt.size(-1), tgt.size(-1)), device=tgt.device), diagonal=1
        ).bool()
        # 组合掩码:填充位置和未来位置都不可见
        tgt_mask = tgt_mask & ~subsequent_mask
        return tgt_mask.unsqueeze(1)  # [batch_size, 1, seq_len, seq_len]
```

### 4.2 训练循环

```python
def run_epoch(data_iter, model, loss_compute, epoch, device):
    """运行一个训练/验证epoch"""
    total_tokens = 0
    total_loss = 0
    tokens = 0
    
    for i, batch in enumerate(data_iter):
        # 将数据移到GPU
        src = batch.src.to(device)
        tgt = batch.tgt.to(device)
        src_mask = batch.src_mask.to(device)
        tgt_mask = batch.tgt_mask.to(device)
        tgt_y = batch.tgt_y.to(device)
        
        # 前向传播
        out = model(src, tgt, src_mask, tgt_mask)
        
        # 计算损失
        loss = loss_compute(out, tgt_y, batch.ntokens)
        
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        
        # 每100步打印一次进度
        if i % 100 == 1:
            print(f"Epoch {epoch}, Step {i}, Loss: {loss:.4f}, Tokens: {tokens}")
            tokens = 0
    
    return total_loss / total_tokens
```

### 4.3 标签平滑

标签平滑是一种正则化技术,通过将目标分布从one-hot分布调整为软分布,防止模型对预测过于自信,从而提高泛化能力。

```python
class LabelSmoothing(nn.Module):
    """标签平滑损失函数"""
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        
    def forward(self, x, target):
        """
        x: 模型输出log概率 [batch_size, seq_len, vocab_size]
        target: 目标索引 [batch_size, seq_len]
        """
        assert x.size(1) == target.size(1)
        true_dist = x.data.clone()
        
        # 初始化为平滑值
        true_dist.fill_(self.smoothing / (self.size - 2))
        
        # 设置正确类别的置信度
        true_dist.scatter_(2, target.unsqueeze(2), self.confidence)
        
        # 忽略填充位置
        true_dist[:, :, self.padding_idx] = 0
        mask = (target == self.padding_idx).unsqueeze(2)
        true_dist.masked_fill_(mask, 0)
        
        self.true_dist = true_dist
        return self.criterion(x, true_dist.detach())
```

### 4.4 学习率调度

Transformer的训练使用了特定的学习率调度策略:前warmup步学习率线性增长,之后按步数的平方根倒数衰减。

```python
class NoamOpt:
    """Transformer的学习率调度器"""
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        """执行优化步骤"""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step=None):
        """计算当前学习率"""
        if step is None:
            step = self._step
        return self.factor * (
            self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5))
        )
```

### 4.5 推理:贪婪解码与束搜索

在推理阶段,Transformer通过自回归方式逐个生成输出词。

```python
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol, device):
    """贪婪解码:每一步选择概率最大的词"""
    memory = model.encode(src, src_mask)
    
    # 初始化解码器输入:只包含起始符号
    tgt = torch.ones(1, 1, device=device).fill_(start_symbol).long()
    
    for i in range(max_len - 1):
        # 创建目标序列掩码
        tgt_mask = Batch.make_std_mask(tgt, pad=0).to(device)
        
        # 解码器前向传播
        out = model.decode(memory, src_mask, tgt, tgt_mask)
        
        # 获取最后一个位置的输出
        prob = model.generator(out[:, -1])
        
        # 选择概率最大的词
        _, next_word = torch.max(prob, dim=-1)
        next_word = next_word.item()
        
        # 将新词添加到序列中
        tgt = torch.cat([tgt, torch.ones(1, 1, device=device).fill_(next_word).long()], dim=-1)
        
        # 如果生成了结束符,停止生成
        if next_word == end_symbol:
            break
    
    return tgt


def beam_search_decode(model, src, src_mask, max_len, start_symbol, end_symbol, 
                       beam_size, device):
    """束搜索解码:维护多个候选序列"""
    memory = model.encode(src, src_mask)
    
    # 初始候选:包含起始符号和初始分数
    candidates = [(torch.tensor([[start_symbol]], device=device), 0.0)]
    
    for _ in range(max_len - 1):
        all_candidates = []
        
        for seq, score in candidates:
            if seq[0, -1].item() == end_symbol:
                # 已经结束的序列直接保留
                all_candidates.append((seq, score))
                continue
            
            # 创建掩码
            tgt_mask = Batch.make_std_mask(seq, pad=0).to(device)
            
            # 解码
            out = model.decode(memory, src_mask, seq, tgt_mask)
            prob = model.generator(out[:, -1])
            log_probs = F.log_softmax(prob, dim=-1)[0]
            
            # 取前beam_size个最优的下一步
            top_log_probs, top_indices = torch.topk(log_probs, beam_size)
            
            for i in range(beam_size):
                next_word = top_indices[i].item()
                next_log_prob = top_log_probs[i].item()
                new_seq = torch.cat([seq, torch.tensor([[next_word]], device=device)], dim=-1)
                new_score = score + next_log_prob
                all_candidates.append((new_seq, new_score))
        
        # 选择分数最高的beam_size个候选
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        candidates = all_candidates[:beam_size]
    
    # 返回分数最高的序列
    return candidates[0][0]
```

## 第五章:Transformer的演进与变体

### 5.1 BERT:双向编码器表示

BERT(Bidirectional Encoder Representations from Transformers)是Transformer架构最重要的演进之一。它采用了Transformer的编码器部分,但进行了两项关键创新:

**双向训练**:与原始Transformer解码器的单向自回归不同,BERT使用双向自注意力,允许每个词同时关注左右两侧的上下文。这种双向性是通过掩码语言模型(Masked Language Model)任务实现的:随机遮蔽输入中15%的词,然后让模型预测这些被遮蔽的词。

**下一句预测**:为了捕捉句子级别的关系,BERT还引入了一个辅助任务:给定两个句子,预测它们是否是连续的。

BERT的预训练-微调范式彻底改变了NLP领域。在大规模语料上预训练的BERT模型,只需要在下游任务数据上进行轻量级微调,就能达到当时最先进的性能。

### 5.2 GPT系列:自回归生成模型

GPT(Generative Pre-trained Transformer)采用了与BERT相反的策略:只使用Transformer的解码器部分,保持自回归的生成方式。GPT系列的演进体现了规模法则(Scaling Laws)的力量:

**GPT-1**:在BookCorpus上预训练,展示了预训练语言模型在多种下游任务上的有效性。

**GPT-2**:扩展到15亿参数,展示了零样本学习的能力,在多种任务上无需微调就能取得不错的效果。

**GPT-3**:扩展到1750亿参数,引入了上下文学习(In-Context Learning)的能力,通过少量示例即可适应新任务。

**GPT-4**:进一步扩展到多模态,展示了接近人类水平的多种能力。

### 5.3 T5:统一的文本到文本框架

T5(Text-to-Text Transfer Transformer)将所有NLP任务统一为文本到文本的格式。无论是分类、翻译还是摘要,模型的输入都是文本字符串,输出也是文本字符串。这种统一使得预训练、微调和推理的流程更加简洁。

T5对Transformer架构进行了一些改进:
- 简化了层归一化(移除偏置项)
- 使用相对位置编码替代绝对位置编码
- 引入了更高效的预训练目标(去噪目标)

### 5.4 高效Transformer

原始Transformer的复杂度为O(n²),在处理长序列时面临巨大的计算和内存压力。为此,研究者提出了多种高效Transformer变体:

**稀疏注意力**:通过限制每个位置的注意力范围来降低复杂度。例如,Longformer使用窗口注意力+全局注意力的组合;BigBird引入了随机注意力和全局注意力。

**低秩近似**:通过将注意力矩阵分解为低秩矩阵的乘积来降低复杂度。Linformer证明,注意力矩阵可以用低秩矩阵近似,从而将复杂度降至O(n)。

**循环注意力**:将长序列分段处理,在段内使用完整注意力,段间使用循环连接。Transformer-XL引入了片段递归机制,能够处理极长序列。

### 5.5 视觉Transformer(ViT)

Transformer在视觉领域的应用始于ViT(Vision Transformer)。ViT将图像分割为固定大小的图像块(patches),将每个块线性投影为向量,然后输入标准Transformer编码器。

ViT的成功表明,Transformer架构不仅适用于序列数据,也能在图像分类等视觉任务上达到甚至超越卷积神经网络(CNN)的性能。此后,视觉Transformer迅速扩展到了目标检测、分割、生成等多个视觉任务。

## 第六章:Transformer的认知科学再审视

### 6.1 注意力机制的生物合理性

回到认知科学的视角,我们可以思考:Transformer的注意力机制在多大程度上是生物合理的?

**优点**:
- **并行性**:人类大脑的视觉和听觉处理高度并行,Transformer的并行计算与这一特点相符
- **选择性**:注意力权重可以理解为神经元的增益调制,这与神经科学中的注意调制机制一致
- **动态性**:注意力分布随上下文变化,反映了认知系统对任务和环境的适应性

**差异**:
- **全局连接**:Transformer中每个位置可以关注所有位置,而生物大脑的神经连接是稀疏的、局部的
- **软注意力**:Transformer使用的是软注意力(所有位置都被加权),而生物注意力有时表现为硬选择(完全抑制无关信息)
- **可逆性**:Transformer的注意力矩阵是对称的(Q和K的角色可以互换),而生物注意力的机制可能是非对称的

### 6.2 多尺度处理的缺失

人类认知系统具有明显的层次化、多尺度处理特性。例如,在阅读时,我们同时处理字母、词、短语、句子、段落等多个层次的结构。原始Transformer虽然通过多层堆叠实现了某种层次化,但缺乏明确的多尺度机制。

最近的Hierarchical Transformer、Longformer等模型试图通过引入不同尺度的注意力来弥补这一缺陷,但距离人类认知的多尺度处理仍有差距。

### 6.3 工作记忆与长期记忆的整合

从认知科学的角度,人类记忆系统分为工作记忆(短期、容量有限)和长期记忆(持久、容量巨大)。Transformer可以类比为一种工作记忆模型:上下文窗口内的信息通过注意力机制进行交互,但超出窗口的信息就会丢失。

Transformer-XL、Memorizing Transformer等模型尝试引入外部记忆来模拟长期记忆,但如何有效整合工作记忆和长期记忆仍然是一个开放的研究方向。

### 6.4 元认知与自我监控

人类的高级认知系统具有元认知能力——能够监控和调节自己的认知过程。当前的Transformer模型缺乏这种自我监控机制。它们无法判断自己何时可能出错,也无法主动调整计算资源来应对困难情况。

将元认知机制引入Transformer,可能是一个重要的未来方向。例如,通过不确定性估计来引导模型在必要时进行更深入的推理,或者通过自适应计算时间来动态分配计算资源。

## 第七章:Transformer的核心代码补充

### 7.1 完整的训练脚本示例

```python
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

class TranslationDataset(Dataset):
    """机器翻译数据集"""
    def __init__(self, src_file, tgt_file, src_vocab, tgt_vocab, max_len=100):
        self.src_data = []
        self.tgt_data = []
        
        # 读取数据(简化实现)
        with open(src_file, 'r', encoding='utf-8') as f_src, \
             open(tgt_file, 'r', encoding='utf-8') as f_tgt:
            for src_line, tgt_line in zip(f_src, f_tgt):
                src_tokens = src_line.strip().split()
                tgt_tokens = tgt_line.strip().split()
                
                # 截断过长序列
                if len(src_tokens) > max_len:
                    src_tokens = src_tokens[:max_len]
                if len(tgt_tokens) > max_len:
                    tgt_tokens = tgt_tokens[:max_len]
                
                # 转换为索引
                src_ids = [src_vocab.get(t, src_vocab['<unk>']) for t in src_tokens]
                tgt_ids = [tgt_vocab.get(t, tgt_vocab['<unk>']) for t in tgt_tokens]
                
                # 添加特殊标记
                src_ids = [src_vocab['<sos>']] + src_ids + [src_vocab['<eos>']]
                tgt_ids = [tgt_vocab['<sos>']] + tgt_ids + [tgt_vocab['<eos>']]
                
                self.src_data.append(src_ids)
                self.tgt_data.append(tgt_ids)
    
    def __len__(self):
        return len(self.src_data)
    
    def __getitem__(self, idx):
        return torch.tensor(self.src_data[idx]), torch.tensor(self.tgt_data[idx])


def collate_batch(batch, src_pad_idx, tgt_pad_idx):
    """批处理函数:填充序列到相同长度"""
    src_batch, tgt_batch = zip(*batch)
    
    # 获取最大长度
    src_len = max([len(s) for s in src_batch])
    tgt_len = max([len(t) for t in tgt_batch])
    
    # 填充
    src_padded = torch.stack([
        torch.cat([s, torch.full((src_len - len(s),), src_pad_idx)]) 
        for s in src_batch
    ])
    
    tgt_padded = torch.stack([
        torch.cat([t, torch.full((tgt_len - len(t),), tgt_pad_idx)]) 
        for t in tgt_batch
    ])
    
    return Batch(src_padded, tgt_padded, pad=src_pad_idx)


def train_transformer(model, train_dataset, val_dataset, src_vocab, tgt_vocab, 
                      num_epochs=10, batch_size=32, lr=0.1, warmup=4000):
    """训练Transformer模型"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # 创建数据加载器
    src_pad_idx = src_vocab['<pad>']
    tgt_pad_idx = tgt_vocab['<pad>']
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=lambda x: collate_batch(x, src_pad_idx, tgt_pad_idx)
    )
    
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size,
        collate_fn=lambda x: collate_batch(x, src_pad_idx, tgt_pad_idx)
    )
    
    # 优化器和学习率调度器
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
    lr_scheduler = NoamOpt(model.d_model, 2, warmup, optimizer)
    
    # 损失函数
    criterion = LabelSmoothing(tgt_vocab['<pad>'], smoothing=0.1)
    
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # 训练
        model.train()
        train_loss = run_epoch(train_loader, model, criterion, epoch, device)
        
        # 验证
        model.eval()
        with torch.no_grad():
            val_loss = run_epoch(val_loader, model, criterion, epoch, device)
        
        print(f'Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}')
        
        # 保存最佳模型
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_transformer.pt')
    
    return model
```

### 7.2 注意力可视化

```python
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(model, src_tokens, tgt_tokens, src_vocab, tgt_vocab, device):
    """可视化注意力权重"""
    model.eval()
    
    # 准备输入
    src_ids = torch.tensor([src_vocab[t] for t in src_tokens]).unsqueeze(0).to(device)
    tgt_ids = torch.tensor([tgt_vocab[t] for t in tgt_tokens]).unsqueeze(0).to(device)
    
    src_mask = (src_ids != src_vocab['<pad>']).unsqueeze(-2)
    tgt_mask = Batch.make_std_mask(tgt_ids, tgt_vocab['<pad>']).to(device)
    
    # 前向传播,收集注意力权重
    memory = model.encode(src_ids, src_mask)
    
    # 手动传播解码器层,收集注意力
    attentions = []
    x = model.tgt_embed(tgt_ids)
    x = model.decoder.norm(x)
    
    for layer in model.decoder.layers:
        # 自注意力
        x = layer.sublayer[0](x, lambda x: layer.self_attn(x, x, x, tgt_mask))
        # 交叉注意力,收集权重
        x = layer.sublayer[1](x, lambda x: layer.src_attn(x, memory, memory, src_mask))
        attentions.append(layer.src_attn.attn)
    
    # 选择最后一层的交叉注意力进行可视化
    attn_weights = attentions[-1][0, 0].cpu().detach().numpy()  # [tgt_len, src_len]
    
    # 绘图
    plt.figure(figsize=(12, 8))
    sns.heatmap(attn_weights, 
                xticklabels=src_tokens, 
                yticklabels=tgt_tokens,
                cmap='Blues',
                annot=True if len(src_tokens) < 20 else False,
                fmt='.2f')
    plt.xlabel('Source Tokens')
    plt.ylabel('Target Tokens')
    plt.title('Cross-Attention Weights')
    plt.tight_layout()
    plt.savefig('attention_visualization.png')
    plt.show()
```

## 第八章:总结与展望

### 8.1 Transformer的核心贡献回顾

Transformer架构的提出,不仅是序列建模技术上的一次飞跃,更是对人工智能范式的一次深刻重塑。其核心贡献可以总结为以下几点:

**彻底摒弃循环结构**:通过自注意力机制替代循环连接,解决了长距离依赖和并行化两大核心问题。这使得Transformer能够高效处理长序列,并且在大规模数据上展现出前所未有的扩展性。

**统一的序列建模框架**:Transformer提供了一个统一的框架,能够处理各种类型的序列数据——自然语言、图像、音频、时序数据等。这种统一性为多模态模型的发展奠定了基础。

**预训练-微调范式的成功**:以BERT和GPT为代表的预训练语言模型,证明了在大规模无监督数据上预训练、再在小规模监督数据上微调的模式,能够极大地提升模型性能。这种范式已经成为现代NLP的标配。

**规模法则的验证**:Transformer系列模型的发展历程验证了一个重要规律:随着模型规模、数据规模和计算规模的同步增长,模型性能呈现出可预测的、持续的提升。这为AI系统的规模化发展提供了理论依据。

### 8.2 认知科学视角的启示

回顾Transformer的发展历程,我们能看到认知科学与人工智能的深刻互动:

一方面,注意力机制的灵感来自认知科学对人类注意力的研究。Transformer将这种直觉转化为可计算的算法,并取得了巨大成功。

另一方面,Transformer的成功也反过来启发了认知科学研究。通过分析Transformer的注意力分布、表示空间和计算模式,研究者可以更好地理解人类认知系统的运作机制。例如,研究发现Transformer的语言表示在某些方面与人类大脑的语言处理模式存在相似性。

这种双向的互动关系预示着,未来人工智能与认知科学的交叉研究将更加深入。

### 8.3 未来的研究方向

尽管Transformer已经取得了巨大的成功,但仍然存在许多值得探索的方向:

**更高效的计算**:原始的O(n²)复杂度限制了Transformer处理超长序列的能力。虽然已经有了许多高效Transformer的变体,但在保持模型能力的同时进一步降低计算复杂度仍然是一个活跃的研究方向。

**更强的推理能力**:当前的Transformer主要学习统计模式,在需要符号推理、逻辑推理的任务上表现有限。如何将符号推理能力融入连接主义模型,是一个重要的问题。

**更深的认知架构**:人类认知系统包含多个相互作用的子系统——感知、记忆、推理、规划等。将Transformer扩展为更完整的认知架构,可能推动通用人工智能的发展。

**可解释性与可信性**:虽然注意力权重提供了一定的可解释性,但Transformer的内部工作机制仍然像一个黑箱。提高模型的可解释性和可信性,对于在医疗、金融等高风险领域的应用至关重要。

**持续学习与适应性**:当前的Transformer在训练后是固定的,缺乏持续学习和适应新环境的能力。如何让模型在部署后继续学习而不忘记已有知识,是一个具有挑战性的问题。

### 8.4 结语

Transformer从认知科学的注意力概念出发,经过巧妙的算法设计,发展成为现代人工智能的基石架构。它不仅在自然语言处理领域取得了革命性突破,还扩展到计算机视觉、语音处理、多模态学习等多个领域。

从0到1理解Transformer,不仅需要掌握其数学原理和代码实现,更需要理解其背后的认知科学思想——注意力作为一种选择性信息处理机制,如何从人类认知中被抽象、形式化,最终转化为可计算的算法。

未来,随着对Transformer理解的深入和技术的持续演进,我们有望见证更强大的AI系统的诞生。而认知科学与人工智能的交叉研究,将继续为这一进程提供重要的洞见和灵感。

正如Attention Is All You Need这篇论文的标题所暗示的,注意力不仅是一切所需,更是连接人类认知与机器智能的桥梁。理解这座桥梁,或许正是理解智能本质的关键一步。

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐