我们需要一个标量来衡量网络预测的好坏。对于回归或二分类问题,一个简单的选择是均方误差 (MSE) 损失。

# xs 是输入数据列表,ys 是对应的目标值列表
ypred = [n(x) for x in xs]
loss = sum((yout - ygt)**2 for ygt, yout in zip(ys, ypred))

损失值越低,表示预测越接近目标。

梯度下降训练循环

训练就是重复以下步骤:

  1. 前向传播:计算预测和损失。

  2. 反向传播:计算损失关于所有参数的梯度 (loss.backward())。

  3. 参数更新:沿着梯度反方向微调参数,以减小损失。

learning_rate = 0.01
for k in range(20): # 训练20步
    # 前向传播
    ypred = [n(x) for x in xs]
    loss = sum((yout - ygt)**2 for ygt, yout in zip(ys, ypred))

    # 反向传播
    for p in n.parameters():
        p.grad = 0.0 # 关键步骤:清零梯度
    loss.backward()

    # 更新参数(梯度下降)
    for p in n.parameters():
        p.data += -learning_rate * p.grad

    print(k, loss.data)

如果学习率设置得当,你应该会看到损失值随着训练步数增加而稳步下降,网络的预测 ypred 也会越来越接近目标 ys

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/b4725668b0b279be9a3c77be83499f5f_1.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/b4725668b0b279be9a3c77be83499f5f_3.png


总结 🎉

在本课程中,我们一起学习了神经网络训练的核心原理:

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/b4725668b0b279be9a3c77be83499f5f_5.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/b4725668b0b279be9a3c77be83499f5f_7.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/b4725668b0b279be9a3c77be83499f5f_8.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/b4725668b0b279be9a3c77be83499f5f_10.png

  1. 神经网络是数学表达式

构建GPT:从零开始,代码详解 🚀

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_1.png

概述

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_3.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_5.png

在本课程中,我们将从零开始构建一个类似ChatGPT的Transformer语言模型。我们将使用PyTorch,在一个小型数据集(Tiny Shakespeare)上训练一个字符级的语言模型,并最终生成莎士比亚风格的文本。通过这个过程,你将深入理解现代大型语言模型(如GPT)的核心工作原理。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_7.png


第1章:引言与背景

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_9.png

你可能已经听说过ChatGPT,它在全球和AI社区引起了巨大轰动。这是一个允许你与AI交互并给予其基于文本任务的系统。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_11.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_13.png

例如,我们可以要求ChatGPT为我们写一首关于理解AI重要性以及如何利用AI改善世界、使其更加繁荣的小俳句。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_15.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_17.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_19.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_21.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_23.png

当我们运行这个请求时,它生成了:“知识带来繁荣,人人可见。拥抱其力量。” 还不错。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_25.png

你可以看到ChatGPT从左到右,依次生成了所有这些单词。

我之前已经用完全相同的提示词问过它一次,它生成了一个略有不同的结果:“AI是成长的力量,忽视则倒退。学习,繁荣在等待。”

两种情况都相当不错,且略有不同。这表明ChatGPT是一个概率性系统,对于任何一个提示,它都可以给出多个答案作为回应。

这只是一个提示词的例子。人们已经想出了许多许多例子,甚至有整个网站专门索引与ChatGPT的互动记录,其中很多都相当幽默。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_27.png

例如,“像对狗解释一样向我解释HTL”,“为第二章写复习笔记”,“写一篇关于埃隆·马斯克收购推特的笔记”等等。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_29.png

再举一个例子:“请写一篇关于一片叶子从树上掉落的突发新闻文章。” 它生成了:“在一个令人震惊的事件中,当地公园的一棵树上一片叶子掉落了。目击者报告称,这片叶子之前附着在树枝上,后来自行脱落掉到了地上。非常戏剧性。”

你可以看到这是一个相当了不起的系统,我们称之为语言模型,因为它对单词、字符或更一般意义上的标记的序列进行建模,并且它知道英语中单词是如何相互跟随的。

从它的角度来看,它所做的是完成序列。我给它一个序列的开头,它用结果来完成这个序列。从这个意义上说,它是一个语言模型。

现在,我想关注ChatGPT工作背后的核心组件。那么,是什么神经网络在幕后对这些单词的序列进行建模呢?

这来自于一篇名为《Attention is All You Need》的论文。2017年,这篇具有里程碑意义的论文提出了Transformer架构

GPT是“Generative Pre-trained Transformer”(生成式预训练Transformer)的缩写。Transformer是真正在幕后完成所有繁重工作的神经网络,它来自2017年的这篇论文。

如果你阅读这篇论文,它读起来像一篇相当随机的机器翻译论文。我认为这是因为作者没有完全预见到Transformer对该领域的影响。他们在机器翻译背景下产生的这种架构,实际上在接下来的五年里接管了AI的其他领域。

因此,这种架构经过微小改动后,被复制粘贴到了近年来AI的大量应用中,这包括了ChatGPT的核心。

我们当然无法重现ChatGPT。这是一个非常严肃的生产级系统,它在互联网的很大一部分数据上进行训练,并且有许多预训练和微调阶段,非常复杂。

我想关注的是训练一个基于Transformer的语言模型。在我们的案例中,它将是一个字符级的语言模型。我仍然认为这对于理解这些系统的工作原理非常有教育意义。

我不想在互联网的大块数据上训练。我们需要一个更小的数据集。我建议使用我最喜欢的玩具数据集,它叫做Tiny Shakespeare

它基本上是莎士比亚所有作品的串联。据我理解,这就是单个文件中的所有莎士比亚作品。这个文件大约1MB,就是所有的莎士比亚。

我们现在要做的是,基本上对这些字符如何相互跟随进行建模。例如,给定这样一段字符,给定过去的一些字符上下文,Transformer神经网络将查看我高亮显示的字符,并预测序列中下一个字符很可能是“G”。

它会这样做,因为我们要在莎士比亚数据上训练这个Transformer,它将尝试生成看起来像这样的字符序列。在这个过程中,它将对这个数据中的所有模式进行建模。

一旦我们训练好系统,我想给你一个预览。我们可以生成无限的莎士比亚文本。当然,这是一个看起来有点像莎士比亚的假东西。

这里有一些我无法解决的干扰,但你可以看到它是如何一个字符一个字符地生成的,它有点像在预测莎士比亚风格的语言。

“诚然,我的主,景象已离开,国王再次带着我的诅咒和珍贵的苍白而来。” 然后特洛伊说了些别的,等等。这只是从Transformer中输出的,方式与ChatGPT非常相似。在我们的案例中是逐字符输出,在ChatGPT中是逐标记输出。标记是这些有点像子词片段的东西,所以它们不是单词级别,而是有点像单词块级别。

我已经写好了训练这些Transformer的完整代码,它在一个GitHub仓库中,叫做nanoGPT

nanoGPT是一个仓库,你可以在我的GitHub上找到,它是一个用于在任何给定文本上训练Transformer的仓库。

我认为它有趣的地方在于,虽然有很多训练Transformer的方法,但这是一个非常简单的实现。它只有两个文件,每个文件大约300行代码:一个文件定义了GPT模型(Transformer),另一个文件在给定的文本数据集上训练它。

这里我展示了如果你在一个开放的网页数据集(一个相当大的网页数据集)上训练它,那么我可以重现GPT-2的性能。GPT-2是OpenAI在2017年发布的GPT早期版本(如果我没记错的话)。到目前为止,我只重现了最小的1.24亿参数模型,但这基本上证明了代码库的排列是正确的,并且我能够加载OpenAI后来发布的神经网络权重。

你可以在这里的nanoGPT中查看完成的代码,但我想在这节课中做的是,基本上从头开始编写这个仓库。

所以我们将从一个空文件开始,然后一块一块地定义一个Transformer。我们将在Tiny Shakespeare数据集上训练它。然后我们将看到如何生成无限的莎士比亚文本。

当然,这可以复制粘贴到任何你喜欢的任意文本数据集上,但我在这里的目标是让你理解并欣赏ChatGPT在幕后是如何工作的。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_31.png

真正需要的只是对Python的熟练程度以及对微积分和统计学的基本理解。如果你也看过我在同一YouTube频道上的先前视频,特别是我的“Make More”系列,在那里我定义了更小更简单的神经网络语言模型(如多层感知机等),那会很有帮助。它真正介绍了语言建模框架。然后在这个视频中,我们将专注于Transformer神经网络本身。


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_33.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_35.png

第2章:数据准备与标记化

好的,我在这里创建了一个新的Google Colab Jupyter笔记本。这将允许我稍后轻松地与你分享我们将要一起开发的代码,以便你可以跟着做,所以这个链接稍后会在视频描述中。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_37.png

现在,这里我已经做了一些准备工作。我从这个URL下载了数据集,即Tiny Shakespeare数据集,你可以看到它大约是一个1MB的文件。

然后我在这里打开input.txt文件,并将所有文本读入一个字符串。

我们看到我们正在处理大约100万个字符。如果我们只打印出前1000个字符,基本上就是你期望的样子。这是Tiny Shakespeare数据集的前1000个字符,大约到这里。

到目前为止一切顺利。接下来,我们将获取这个文本。在Python中,文本是一个字符序列。所以当我对它调用set构造函数时,我只会得到这个文本中出现的所有字符的集合。

然后我对它调用list来创建这些字符的列表,而不仅仅是一个集合,这样我就有了一个任意的排序。然后我对其进行排序。

基本上,我们得到了整个数据集中出现的所有字符,并且它们被排序了。现在,它们的数量将成为我们的词汇表大小。这些是我们序列的可能元素。

我们看到当我在这里打印字符时,总共有65个。有一个空格字符,然后是各种特殊字符,接着是大写和小写字母。这就是我们的词汇表,这是模型可以看到或发出的可能字符。

好的,接下来我们需要开发一些策略来对输入文本进行标记化

当人们说标记化时,他们指的是根据某种可能元素的词汇表,将原始文本字符串转换为某种整数序列。

例如,在这里我们将构建一个字符级语言模型,所以我们只需将单个字符转换为整数。

让我给你看一段代码,它为我们实现了这一点。我们同时构建编码器和解码器。

让我解释一下这里发生了什么。当我们编码一个任意文本,比如“hi there”,我们将收到一个代表该字符串的整数列表。例如,46,47等等。

然后我们也有反向映射,所以我们可以获取这个列表并将其解码,以得到完全相同的字符串。

所以这真的就像是对任意字符串进行整数转换和反向转换。对我们来说,这是在字符级别完成的。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_39.png

实现这一点的方式是,我们只是遍历这里的所有字符,并创建一个从字符到整数的查找表,反之亦然。然后要编码某个字符串,我们只需单独翻译所有字符;要将其解码回来,我们使用反向映射并将所有内容连接起来。

这只是众多可能的编码或标记化方案之一,而且是一个非常简单的方案。但在实践中,人们提出了许多其他方案。例如,谷歌使用SentencePiece。

SentencePiece也会将文本编码为整数,但使用不同的方案和不同的词汇表,SentencePiece是一种子词标记器。这意味着你不是编码整个单词,但也不是编码单个字符。它通常在子词单元级别。这通常是实践中采用的方案。例如,OpenAI也有这个叫做Tiktoken的库,它使用字节对编码标记器,这就是GPT使用的。

你也可以直接将单词编码为整数。例如,我在这里使用Tiktoken库。我获取了GPT-2使用的编码。他们不是只有65个可能的字符或标记,而是有50,257个标记。所以当他们编码完全相同的字符串“hi there”时,我们只得到一个包含三个整数的列表,但这些整数不是在0到64之间,而是在0到50,256之间。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_41.png

基本上,你可以在码本大小和序列长度之间进行权衡。你可以有非常长的序列和非常小的词汇表,或者你可以有非常短的序列和非常大的词汇表。因此,在实践中,人们通常使用这些子词编码,但我想保持我们的标记器非常简单,所以我们使用字符级标记器。

这意味着我们有非常小的码本,我们有非常简单的编码和解码函数。但结果是我们得到了非常长的序列。但这就是我们在这节课中要坚持的水平,因为它是最简单的事情。

好的,现在我们有了编码器和解码器,实际上就是一个标记器,我们可以对整个莎士比亚训练集进行标记化。

这是一段实现此功能的代码。我将开始使用PyTorch库,特别是来自PyTorch库的torch.tensor

我们将获取Tiny Shakespeare中的所有文本,对其进行编码,然后将其包装到torch.tensor中以获取数据张量。

这是数据张量的样子,当我只看前1000个字符或它的前1000个元素时。

我们看到我们有一个巨大的整数序列,这个整数序列基本上是前100个字符的相同翻译。例如,我相信0是一个换行符,也许1是一个空格,我不完全确定。但从现在开始,整个文本数据集被表示为一个单一的、非常长的整数序列。

在继续之前,让我再做一件事。我想将我们的数据集分成训练集和验证集。

具体来说,我们将取数据集的前90%作为Transformer的训练数据,我们将保留最后10%作为验证数据。

这将帮助我们理解模型在多大程度上过拟合。我们将基本上隐藏并保留验证数据。

因为我们不想要对这个确切的莎士比亚文本的完美记忆。我们想要一个能生成类似莎士比亚文本的神经网络。因此,它应该相当有可能产生实际的、被隐藏的、真正的莎士比亚文本。我们将用它来了解过拟合情况。


第3章:批处理与数据加载

好的,现在我们想开始将这些文本序列或整数序列输入到Transformer中,以便它可以训练并学习这些模式。

需要认识到的重要一点是,我们永远不会一次性将整个文本输入到Transformer中,这在计算上非常昂贵且不可行。

因此,当我们在大量数据集上训练Transformer时,我们只处理数据块。当我们训练Transformer时,我们基本上从训练集中随机抽取小块,并一次只在小块上训练。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_43.png

这些块基本上有一个长度,这是一个最大长度。通常,至少在我通常编写的代码中,这个最大长度被称为块大小

让我们从块大小仅为8开始。让我看看训练数据的前几个字符,前块大小+1个字符。我稍后会解释为什么是加一。

这是训练集中序列的前九个字符。

现在我想指出的是,当你像这样采样一个数据块时,比如从训练集中采样这九个字符,这实际上包含了多个打包在一起的例子。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_45.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_47.png

这是因为所有这些字符都是相互跟随的。所以当我们将其输入Transformer时,这个东西要说的是,我们将同时训练它在所有这些位置进行预测。

在一个包含9个字符的块中,实际上有8个独立的例子打包在里面。例如,在上下文为“18”时,“47”很可能接下来出现;在上下文为“18”和“47”时,“56”接下来出现;在上下文为“18, 47, 56”时,“57”接下来出现,依此类推。这就是那8个独立的例子。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_49.png

让我用代码来详细说明。这是一段代码来说明。x是Transformer的输入,它只是前块大小个字符。y是接下来的块大小个字符,所以它偏移了一个位置。这是因为y是输入中每个位置的目标。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_51.png

然后我在这里遍历所有的块大小。上下文始终是x中直到并包括t的所有字符,而目标始终是目标数组y中的第t个字符。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_53.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_55.png

让我运行这个。它基本上用代码阐述了我刚才说的话。这些是隐藏在我们从训练集中采样的9个字符块中的八个例子。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_57.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_58.png

我想再提一件事。我们在这里训练所有八个例子,上下文从1一直到块大小。我们这样做不仅仅是为了计算原因(因为我们恰好已经有了序列之类的原因),也不仅仅是为了效率。这样做也是为了让Transformer网络习惯于看到从少到多(从1一直到块大小)的各种上下文。我们希望Transformer习惯于看到所有情况。这在稍后的推理过程中会很有用,因为当我们进行采样生成时,我们可以从少到一个字符的上下文开始生成,然后Transformer知道如何仅用少到一个字符的上下文来预测下一个字符。然后它可以预测直到块大小的所有内容。在块大小之后,我们必须开始截断,因为Transformer在预测下一个字符时永远不会接收到超过块大小的输入。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_60.png

好的,我们已经研究了将要输入到Transformer的张量的时间维度。还有一个维度需要关心,那就是批次维度

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_62.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_64.png

当我们采样这些文本块时,实际上每次我们将它们输入Transformer时,我们都会有许多批次的多个文本块,它们都堆叠在一个张量中。这样做只是为了效率,以便我们可以让GPU保持忙碌,因为它们非常擅长数据的并行处理。所以我们只想同时处理多个块,但这些块是完全独立处理的,它们彼此不通信。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_66.png

让我基本上概括一下并引入批次维度。这是一段代码。让我运行它,然后我会解释它做了什么。

这里,因为我们将开始采样数据集中的随机位置来提取块,我设置了随机数生成器的种子,以便我在这里看到的数字与你稍后尝试重现时看到的数字相同。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_68.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_70.png

这里的batch_size是我们在Transformer的每次前向/后向传播中处理的独立序列的数量。block_size,正如我解释的,是进行这些预测的最大上下文长度。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_72.png

假设batch_size为4,block_size为8,然后这是我们获取批次的方式。对于任意的分割(训练或验证),如果分割是训练分割,那么我们将查看train_data,否则查看val_data。这给了我们数据数组。

然后当我生成随机位置以从中抓取一个块时,我实际上生成batch_size个随机偏移量。因为这是4,ix将是四个在0到数据长度 - 块大小之间随机生成的数字。这只是训练集中的随机偏移量。

然后x,正如我解释的,是从i开始的第一个块大小字符。y是它的偏移一个位置(即加一)。然后我们将为ix中的每个整数i获取这些块,并使用torch.stack将所有那些一维张量(就像我们在这里看到的那样)堆叠起来。它们都变成了一个4x8张量中的一行。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_74.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_76.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_78.png

然后这是我打印的内容。我采样了batch_xbatch_y。Transformer的输入x现在是4x8张量,4行8列,每一行都是训练集的一个块。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_80.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_82.png

然后相关的目标在数组y中,它们将在最后进入Transformer以创建损失函数,所以它们将为我们提供x中每个单个位置的正確答案。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_84.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_85.png

然后这些是四个独立的行。正如我们之前详细说明的那样,这个4x8数组总共包含32个例子,就Transformer而言,它们是完全独立的。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_87.png

所以,当输入是“24”时,目标是“43”,或者更准确地说,是y数组中的“43”。当输入是“24, 43”时,目标是“58”。当输入是“24, 43, 58”时,目标是“5”,依此类推。或者当它是“52, 58, 1”时,目标是“58”。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_89.png

所以你可以看到这个详细说明。这些是打包到单个输入批次x中的32个独立例子,然后期望的目标在y中。

现在,这个整数张量x将被输入到Transformer中。那个Transformer将同时处理所有这些例子,然后在张量y中的所有这些位置查找要预测的正确整数。


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_91.png

第4章:构建基础模型(Bigram模型)

好的,现在我们有了想要输入到Transformer的输入批次,让我们开始将其输入到神经网络中。

我们将从最简单的神经网络开始,在我看来,对于语言建模来说,这就是Bigram语言模型。我在我的“Make More”系列中已经深入介绍了Bigram语言模型。

所以在这里我将进行得更快一些,让我们直接实现实现Bigram语言模型的PyTorch模块。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_93.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_95.png

我在这里导入PyTorch。为了可重现性。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_97.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_99.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_101.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_102.png

然后我在这里构建一个Bigram语言模型,它是nn.Module的子类。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_104.png

然后我调用它,并传入输入和目标。我只是打印一下。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_106.png

当输入和目标到达这里时。你看到我只是获取索引,这里的输入x(我重命名为idx),我只是将它们传入这个标记嵌入表

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_108.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_110.png

这里在构造函数中发生的是,我们正在创建一个标记嵌入表。它的大小是词汇表大小 x 词汇表大小。我们使用nn.Embedding,它基本上是一个形状为词汇表大小 x 词汇表大小的张量的非常薄的包装器。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_112.png

这里发生的是,当我们在这里传递idx时,我们输入中的每个整数都将引用这个嵌入表,并提取出与其索引对应的嵌入表的一行。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_114.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_116.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_118.png

所以这里的“24”将进入嵌入表并提取出第24行。“43”将进入这里并提取出第43行,等等。然后PyTorch将把所有这些东西排列成一个批次 x 时间 x 通道的张量。

在这种情况下,批次是4,时间是8,C(通道)是词汇表大小,即65。

所以我们只是提取所有这些行,将它们排列成B x T x C。现在我们将把它解释为逻辑值,这些基本上是序列中下一个字符的分数。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_120.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_122.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_124.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_126.png

这里发生的是,我们仅基于单个标记的身份来预测接下来会出现什么。你可以这样做,因为目前标记之间不相互通信,除了它们只看到自己之外,它们看不到任何上下文。我是一个标记编号5。然后我实际上可以通过知道我是标记5来对接下来可能出现的内容做出相当不错的预测,因为在典型场景中,某些字符知道跟随其他字符。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_128.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_129.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/bf8e317787d215eeac53dc1b82b1c757_131.png

我们在“Make More”系列中更深入地看到了很多这方面的内容。在这里,如果我运行这个,那么我们现在得到预测,分数,逻辑值。对于4x8个位置中的每一个。

语言建模的详解入门:构建 makemore:2:多层感知机 (MLP)

在本节课中,我们将继续构建 makemore 项目。上一节我们介绍了基于计数的二元语言模型,以及一个使用单层线性神经网络的简单实现。本节中,我们将探讨如何利用多层感知机模型来预测序列中的下一个字符,这种方法能更好地处理更长的上下文信息。

概述

上一节我们实现的模型仅考虑前一个字符来预测下一个字符。这种方法虽然简单,但预测效果不佳,因为上下文信息太少。如果我们尝试考虑更多上下文(例如前两个或三个字符),可能性组合的数量会呈指数级增长(例如,考虑三个字符时,有 27^3 = 19683 种可能的上下文组合),导致数据稀疏,模型无法有效学习。

为了解决这个问题,我们将实现一个多层感知机模型。该模型的灵感来源于 Bengio 等人在 2003 年发表的论文《A Neural Probabilistic Language Model》。该论文提出使用词嵌入(将单词映射为低维向量)和神经网络来预测下一个词,从而能够通过向量空间的相似性来泛化到未见过的上下文组合。

模型架构详解

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/24b592ee3726a6bddb17d0caaed5664a_1.png

论文中的模型是一个词级语言模型,但我们将遵循其架构思想,构建一个字符级模型。核心思想是为每个字符学习一个低维的嵌入向量,然后使用一个神经网络基于多个前序字符的嵌入向量来预测下一个字符。

以下是模型的关键组成部分:

  1. 嵌入层 (Embedding Layer):这是一个查找表 C,其大小为 [vocab_size, embedding_dim]vocab_size 是词汇表大小(对我们来说是 27 个字符),embedding_dim 是嵌入向量的维度(例如 2、10、30 等)。该层将输入的字符索引(整数)转换为对应的嵌入向量。

  2. 输入层 (Input Layer):我们将多个(例如 3 个)前序字符的嵌入向量拼接起来,形成一个更长的输入向量。如果嵌入维度是 2,上下文长度是 3,那么输入向量的长度就是 3 * 2 = 6

  3. 隐藏层 (Hidden Layer):这是一个全连接层,具有 hidden_size 个神经元(例如 100、200),并使用 tanh 作为激活函数。该层接收拼接后的嵌入向量作为输入。

  4. 输出层 (Output Layer):这是另一个全连接层,其神经元数量等于词汇表大小(27)。它接收隐藏层的输出,并产生每个可能的下一个字符的“得分”(logits)。

  5. Softmax 层:将输出层的 logits 通过 softmax 函数转换为概率分布,所有可能字符的概率之和为 1。

模型的参数包括嵌入矩阵 C、隐藏层的权重和偏置、以及输出层的权重和偏置。训练时,我们通过反向传播最大化训练数据中真实下一个字符的对数似然(即最小化交叉熵损失)。

代码实现步骤

现在,让我们一步步实现这个模型。

1. 准备数据集

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/24b592ee3726a6bddb17d0caaed5664a_3.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/24b592ee3726a6bddb17d0caaed5664a_5.png

首先,我们需要构建数据集。与之前不同,现在每个输入样本是多个(block_size 个)连续的字符,标签是紧随其后的那个字符。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/24b592ee3726a6bddb17d0caaed5664a_7.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/24b592ee3726a6bddb17d0caaed5664a_9.png

import torch
import matplotlib.pyplot as plt

# 读取数据
words = open('names.txt', 'r').read().splitlines()
print(f"总单词数: {len(words)}")
print(words[:8])

# 构建字符词汇表
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(f"词汇表: {itos}")
print(f"词汇表大小: {vocab_size}")

# 构建数据集函数
def build_dataset(words, block_size=3):
    X, Y = [], []
    for w in words:
        context = [0] * block_size # 初始用 '.' (索引0) 填充
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)       # 输入是当前上下文
            Y.append(ix)            # 标签是下一个字符
            context = context[1:] + [ix] # 滑动窗口,移出最旧的字符,加入新字符
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(f"构建了 {len(X)} 个样本")
    return X, Y

# 划分训练集、验证集和测试集
import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

Xtr, Ytr = build_dataset(words[:n1])   # 训练集 (80%)
Xdev, Ydev = build_dataset(words[n1:n2]) # 验证集 (10%)
Xte, Yte = build_dataset(words[n2:])    # 测试集 (10%)

2. 初始化模型参数

接下来,我们初始化模型的所有参数:嵌入矩阵、隐藏层和输出层的权重与偏置。

# 超参数
block_size = 3        # 上下文长度
embedding_dim = 10    # 每个字符的嵌入维度
hidden_size = 200     # 隐藏层神经元数量
learning_rate = 0.1
iterations = 100000

# 初始化参数
g = torch.Generator().manual_seed(2147483647) # 为了可复现性

# 嵌入层 C: [vocab_size, embedding_dim]
C = torch.randn((vocab_size, embedding_dim), generator=g)

# 隐藏层: 输入维度 = block_size * embedding_dim, 输出维度 = hidden_size
W1 = torch.randn((block_size * embedding_dim, hidden_size), generator=g) * 0.2
b1 = torch.randn(hidden_size, generator=g) * 0.01

# 输出层: 输入维度 = hidden_size, 输出维度 = vocab_size
W2 = torch.randn((hidden_size, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g) * 0

# 将所有参数放入一个列表,方便管理
parameters = [C, W1, b1, W2, b2]
print(f"参数总数: {sum(p.nelement() for p in parameters)}")

# 告诉 PyTorch 这些张量需要计算梯度
for p in parameters:
    p.requires_grad = True

3. 训练循环

我们将使用小批量随机梯度下降来训练模型。在每个迭代中,我们从训练集中随机抽取一批样本,计算前向传播的损失,然后执行反向传播来更新参数。

# 记录损失用于绘图
lossi = []

for i in range(iterations):
    # 1. 构造一个小批量 (mini-batch)
    ix = torch.randint(0, Xtr.shape[0], (32,)) # 批量大小 32
    Xb, Yb = Xtr[ix], Ytr[ix]

    # 2. 前向传播
    # 嵌入层: 将输入索引转换为嵌入向量
    emb = C[Xb] # 形状: [batch_size, block_size, embedding_dim]
    # 将多个嵌入向量在最后一个维度之后拼接起来
    # 使用 `view` 来高效地改变形状,相当于拼接
    emb_concatenated = emb.view(emb.shape[0], -1) # 形状: [batch_size, block_size * embedding_dim]

    # 隐藏层 (带 tanh 激活)
    h = torch.tanh(emb_concatenated @ W1 + b1) # 形状: [batch_size, hidden_size]

    # 输出层 (logits)
    logits = h @ W2 + b2 # 形状: [batch_size, vocab_size]

    # 计算损失 (交叉熵)
    loss = F.cross_entropy(logits, Yb)

    # 3. 反向传播
    for p in parameters:
        p.grad = None # 将梯度置零
    loss.backward()

    # 4. 更新参数 (梯度下降)
    lr = learning_rate
    if i > iterations // 2: # 后半程学习率衰减
        lr = 0.01
    for p in parameters:
        p.data += -lr * p.grad

    # 记录损失
    lossi.append(loss.log10().item())

# 绘制训练损失曲线
plt.plot(lossi)
plt.title('Training Loss (log10 scale)')
plt.show()

4. 模型评估

训练完成后,我们在训练集和验证集上评估模型的损失,以检查是否过拟合或欠拟合。

@torch.no_grad() # 在此上下文中不计算梯度,节省内存和计算
def evaluate_loss(X, Y):
    emb = C[X]
    emb_concatenated = emb.view(emb.shape[0], -1)
    h = torch.tanh(emb_concatenated @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y)
    return loss.item()

train_loss = evaluate_loss(Xtr, Ytr)
dev_loss = evaluate_loss(Xdev, Ydev)
print(f'训练集损失: {train_loss:.4f}')
print(f'验证集损失: {dev_loss:.4f}')

5. 从模型生成样本

最后,我们可以使用训练好的模型来生成新的“名字”。

# 生成样本
g = torch.Generator().manual_seed(2147483647 + 10)
for _ in range(20):
    out = []
    context = [0] * block_size # 以 '.' 开始
    while True:
        # 前向传播
        emb = C[torch.tensor([context])] # 注意增加批次维度
        emb_concatenated = emb.view(1, -1)
        h = torch.tanh(emb_concatenated @ W1 + b1)
        logits = h @ W2 + b2
        probs = F.softmax(logits, dim=1)
        # 根据概率分布采样下一个字符
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        # 更新上下文
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0: # 遇到 '.' 结束
            break
    # 将索引解码为字符并打印
    print(''.join(itos[i] for i in out))

总结

本节课中,我们一起学习了如何构建一个基于多层感知机的字符级语言模型。我们首先分析了仅考虑短上下文的局限性,然后引入了 Bengio 2003 年论文中的思想,即使用嵌入向量和神经网络来捕获更长的依赖关系并实现泛化。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/24b592ee3726a6bddb17d0caaed5664a_11.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/24b592ee3726a6bddb17d0caaed5664a_12.png

我们详细实现了模型的各个部分:嵌入层、隐藏层和输出层,并使用小批量随机梯度下降进行训练。我们还介绍了如何划分数据集以进行正确的模型评估,以及如何从训练好的模型中采样生成新的序列。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/24b592ee3726a6bddb17d0caaed5664a_14.png

通过调整嵌入维度、隐藏层大小、上下文长度和优化超参数,你可以进一步降低损失,生成更逼真的“名字”。这个简单的 MLP 模型为我们理解更复杂的现代语言模型(如 RNN、LSTM 和 Transformer)奠定了重要的基础。

语言建模的详解入门:构建 makemore:P3:激活与梯度,批归一化 🧠

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_1.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_2.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_4.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_6.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_7.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_8.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_10.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_12.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_14.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_15.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_17.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_19.png

在本节课中,我们将继续构建 makemore 项目,并深入探讨神经网络训练中的两个核心概念:激活值梯度。我们将分析它们在训练过程中的行为,并学习如何通过权重初始化批归一化等技术来稳定和优化深度神经网络的训练。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_20.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_22.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_24.png


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_26.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_28.png

概述:为什么需要关注激活与梯度?

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_30.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_32.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_34.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_36.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_38.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_40.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_42.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_44.png

上一节我们实现了一个用于字符级语言建模的多层感知机。在转向更复杂的循环神经网络之前,我们需要对神经网络内部的激活值和反向传播的梯度有更深入的理解。理解它们的分布和行为,对于理解为何某些网络结构难以优化以及后续的改进技术至关重要。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_46.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_48.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_50.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_52.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_54.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_56.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_58.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_59.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_61.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_63.png


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_65.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_67.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_69.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_70.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_72.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_74.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_76.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_78.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_80.png

1. 审视初始化的损失值 🔍

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_82.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_84.png

首先,我们审视当前模型的初始化状态。在第一次迭代时,我们记录到的损失值高达 27,这远高于我们的预期。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_86.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_88.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_90.png

对于一个有 27 个可能字符的分类问题,在初始化时,网络没有理由偏向任何一个字符。因此,我们希望输出是一个均匀分布,每个字符的概率约为 1/27。对应的期望损失(负对数似然)应为:

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_92.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_94.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_95.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_97.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_99.png

import torch
expected_loss = -torch.log(torch.tensor(1/27.0))
# 输出:tensor(3.2958)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_101.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_103.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_104.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_106.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_107.png

初始损失 27 意味着网络在初始化时就“过于自信地犯错”,输出概率分布极不均匀。这通常是由于输出层的权重 W2 和偏置 B2 初始化不当,导致 logits 值过大。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_109.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_110.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_112.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_114.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_116.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_118.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_120.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_121.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_123.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_125.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_126.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_128.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_130.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_132.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_134.png

解决方案:我们应缩小输出层权重 W2 的尺度,并将偏置 B2 初始化为零或接近零的小值。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_136.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_138.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_140.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_142.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_143.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_145.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_147.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_149.png

# 改进的初始化
W2 = torch.randn(n_hidden, vocab_size) * 0.01  # 缩小权重
B2 = torch.zeros(vocab_size)                    # 偏置置零

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_151.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_153.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_155.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_157.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_159.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_161.png

调整后,初始损失会降至接近 3.3 的合理范围,避免了训练初期不必要的“权重压缩”阶段,使优化更高效。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_163.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_165.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_167.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_168.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_170.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_172.png


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_174.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_176.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_178.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_180.png

2. 隐藏层激活的饱和问题 ⚠️

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_182.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_183.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_185.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_187.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_188.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_189.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_191.png

即使输出层初始化正确,隐藏层的激活值也可能存在问题。我们使用 tanh 作为激活函数,其输出范围在 (-1, 1)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_193.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_195.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_197.png

在初始化时,如果输入到 tanh 的预激活值 H_preact 过大(例如,绝对值远大于 0),tanh 的输出就会饱和在 -11 附近。这会导致一个严重问题:在反向传播时,梯度会消失。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_199.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_201.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_203.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_205.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_206.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_208.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_210.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_211.png

梯度消失的原理

tanh 的导数为 1 - tanh(x)^2。当 tanh(x) 接近 ±1 时,其导数接近 0。这意味着,对于饱和的神经元,其权重和偏置几乎无法通过梯度下降进行更新。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_213.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_214.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_215.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_217.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_219.png

我们可以通过直方图检查隐藏层激活 H 的分布:

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_221.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_223.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_225.png

import matplotlib.pyplot as plt
plt.hist(H.view(-1).tolist(), bins=50)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_227.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_229.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_231.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_233.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_235.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_237.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_239.png

如果大量激活值集中在 -11 附近,就表明存在饱和问题。更严重的是,如果某个神经元对所有样本都饱和(即“死亡神经元”),它将永远无法学习。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_241.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_242.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_244.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_246.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_248.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_250.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_252.png

解决方案:同样,我们需要调整第一层权重 W1 和偏置 B1 的初始化尺度,使 H_preact 的分布更集中,避免进入 tanh 的饱和区。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_254.png

# 改进的初始化
W1 = torch.randn(block_size * n_embd, n_hidden) * 0.2  # 适当缩小
B1 = torch.zeros(n_hidden) * 0.01                      # 小偏置引入熵

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_256.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_258.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_260.png

通过调整,H_preact 的分布会更接近均值为 0、标准差适中的高斯分布,tanh 的激活值将更多地落在其敏感的非饱和区域。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_262.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_264.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_266.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_268.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_270.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_272.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_274.png


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_276.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_278.png

3. 权重初始化的数学原理 📐

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_280.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_282.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_283.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_285.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_286.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_287.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_289.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_290.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_292.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_294.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_296.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_297.png

对于深度网络,手动调整每个层的缩放因子是不现实的。我们需要一个系统化的初始化方法。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_299.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_301.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_303.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_305.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_307.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_308.png

考虑一个简单的线性层:y = x @ W,其中 xW 的元素都采样自标准正态分布 N(0, 1)。如果 x 的维度(fan_in)为 n,那么 y 中每个元素的方差将是 n(因为它是 n 个独立方差为 1 的变量之和)。因此,y 的标准差是 sqrt(n)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_310.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_312.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_313.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_315.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_317.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_319.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_321.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_323.png

为了使 y 的方差保持为 1(即保持激活尺度稳定),我们需要将权重 W 的尺度缩放 1 / sqrt(fan_in)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_325.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_326.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_328.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_329.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_331.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_332.png

Kaiming 初始化

对于使用 ReLU 等激活函数的网络,由于它们会“丢弃”一半的分布(负值置零),需要进行补偿。Kaiming He 等人提出的初始化方法将权重的标准差设为 sqrt(2 / fan_in)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_334.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_336.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_338.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_340.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_341.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_343.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_344.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_345.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_347.png

对于 tanh 激活函数,PyTorch 中使用的增益因子是 5/3。因此,初始化公式为:

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_349.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_350.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_352.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_354.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_356.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_357.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_358.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_360.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_361.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_363.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_364.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_365.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_367.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_369.png

gain = 5/3  # tanh 的推荐增益
std = gain / math.sqrt(fan_in)
W = torch.randn(fan_in, fan_out) * std

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_371.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_372.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_373.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_375.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_376.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_378.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_380.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_382.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_384.png

这种半原则性的方法可以扩展到更深层的网络,无需手动调参。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_386.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_388.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_389.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_390.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_392.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_394.png


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_396.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_398.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_399.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_401.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_403.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_405.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_406.png

4. 批归一化:自动控制激活分布 🛡️

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_408.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_409.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_411.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_413.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_415.png

尽管精心设计的初始化有助于稳定训练,但对于非常深的网络,精确控制所有层的激活分布仍然非常困难。批归一化 应运而生,它通过一个可微操作,在训练过程中主动将激活值标准化。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_417.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_418.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_420.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_422.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_424.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_426.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_428.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_430.png

批归一化的核心思想

如果希望某层的输入(即前一层的激活)是均值为 0、方差为 1 的高斯分布,为什么不直接对每个小批量的数据进行标准化呢?

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_432.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_434.png

批归一化层的操作

对于一个输入张量 x(形状为 [batch_size, features]):

  1. 计算当前小批量的均值 μ_B 和方差 σ_B^2

  2. 对输入进行标准化:x_hat = (x - μ_B) / sqrt(σ_B^2 + ε),其中 ε 是一个防止除零的小常数。

  3. 引入可学习的缩放参数 γ 和偏移参数 βy = γ * x_hat + β

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_436.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_437.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_439.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_441.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_443.png

在训练时,使用当前批次的统计量。同时,该层会维护一个运行均值运行方差,通过指数移动平均在训练过程中更新。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_444.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_445.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_447.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_448.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_450.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_452.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_454.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_455.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_457.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_458.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_460.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_461.png

在推理(或测试)时,则使用训练过程中积累的运行统计量,而不是当前批次的统计量。这使得网络可以对单个样本进行前向传播。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_463.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_465.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_467.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_469.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_471.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_472.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_474.png

批归一化的优势与代价

  • 优势:极大地稳定了深度网络的训练,对初始化不那么敏感,并且由于引入了批次内的噪声,还具有一定的正则化效果。

  • 代价:它耦合了批次内样本的计算。一个样本的输出会受到同批次其他样本的影响,这有时会导致意想不到的 bug,并使推理逻辑复杂化。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_476.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_477.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_478.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_479.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_481.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_482.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_484.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_486.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_488.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_490.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_492.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_494.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_495.png

代码示例

# 训练阶段
bn_mean = x.mean(dim=0, keepdim=True)
bn_var = x.var(dim=0, keepdim=True, unbiased=False)
x_hat = (x - bn_mean) / torch.sqrt(bn_var + 1e-5)
y = bn_gain * x_hat + bn_bias

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_497.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_498.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_499.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_500.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_502.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_504.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_505.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_507.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_509.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_511.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_513.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_514.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_515.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_516.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_517.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_519.png>

# 更新运行统计量 (动量更新)
running_mean = 0.999 * running_mean + 0.001 * bn_mean
running_var = 0.999 * running_var + 0.001 * bn_var

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_521.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_523.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_525.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_527.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_529.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_531.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_532.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_533.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_535.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_536.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_538.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_540.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_542.png>

<https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_544.png>

# 推理阶段
x_hat = (x - running_mean) / torch.sqrt(running_var + 1e-5)
y = bn_gain * x_hat + bn_bias

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_546.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_548.png

使用注意:当线性层后接批归一化层时,线性层的偏置 bias 是多余的,因为它会在归一化时被减去。因此,通常将线性层的 bias=False,让批归一化层的 β 参数负责偏移。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_550.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_551.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_552.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_554.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_556.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_558.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_560.png


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_562.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_564.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_566.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_568.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_570.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_571.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_573.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_575.png

5. 诊断工具:监控训练健康度 📊

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_577.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_579.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_581.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_583.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_585.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_587.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_589.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_591.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_593.png

在训练神经网络时,监控以下统计量对于诊断问题至关重要:

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_595.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_597.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_599.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_601.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_602.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_604.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_606.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_608.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_610.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_612.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_614.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_615.png

  1. 前向传播激活直方图:检查各层激活值是否过度饱和或过度稀疏。

  2. 反向传播梯度直方图:检查梯度是否消失或爆炸。

  3. 参数更新比率:计算 (学习率 * 梯度).std() / 参数.std() 的对数。这个比率反映了参数相对其当前值的变化幅度。一个经验法则是,这个比率在 1e-3 左右(即对数尺度下约为 -3)是比较健康的。如果远大于此,可能学习率太高;如果远小于此,可能学习率太低。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_617.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_619.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_620.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_621.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_623.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_625.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_627.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_629.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_631.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_633.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_634.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_635.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_636.png

通过绘制这些比率随训练迭代的变化曲线,可以直观地判断网络各层的学习速度是否均衡,以及学习率设置是否合适。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_638.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_639.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_641.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_643.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_645.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_647.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_649.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_651.png


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_653.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_655.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_657.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_659.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_661.png

总结 🎯

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_663.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_665.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_667.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_669.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_671.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_672.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_674.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_675.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_677.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_679.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_681.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_683.png

本节课我们一起深入探讨了神经网络训练的内部机制:

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_685.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_686.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_687.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_689.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_690.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_692.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_693.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_695.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_696.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_697.png

  1. 初始化的重要性:不恰当的初始化会导致输出层过于自信、隐藏层饱和,从而使得训练初期低效,甚至阻碍网络学习。我们学习了如何通过缩放权重来获得合理的初始激活分布。

  2. 梯度流与饱和:我们分析了 tanhReLU 等激活函数在饱和区会导致梯度消失的问题,理解了“死亡神经元”的概念。

  3. 系统化初始化:介绍了基于方差守恒原则的 Kaiming 初始化方法,它为我们提供了一种无需手动调参的、可扩展的初始化策略。

  4. 批归一化:作为一种强大的技术,批归一化通过在网络内部显式地标准化激活值,极大地稳定了深度网络的训练,降低了对初始化的精细要求。我们也了解了其内部机制和使用时的注意事项。

  5. 训练诊断:我们学习了一套实用的工具,用于监控前向激活、反向梯度以及参数更新比率,这些工具能帮助我们判断网络是否处于健康训练状态。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_699.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_701.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_702.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_703.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_704.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_705.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_706.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_708.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_709.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_711.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_712.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/88dd9899d3a9ef459d32c32418af4090_714.png

理解激活和梯度的行为是掌握深度学习的基石。虽然现代技术(如批归一化、残差连接、高级优化器)让训练变得更加鲁棒,但底层原理依然至关重要,尤其是在设计和调试新模型架构时。在接下来的课程中,我们将把这些知识应用于更复杂的循环神经网络。

语言建模的详解入门:4:成为反向传播高手 🥷

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_1.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_2.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_4.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_6.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_7.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_9.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_11.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_13.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_15.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_17.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_18.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_19.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_20.png

在本节课中,我们将学习如何手动实现神经网络的反向传播,摆脱对 PyTorch loss.backward() 的依赖。通过亲手计算张量级别的梯度,我们将深入理解反向传播的内部机制,这对于调试神经网络和优化模型至关重要。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_22.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_24.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_26.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_27.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_29.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_30.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_31.png

概述

到目前为止,我们已经实现了一个包含批归一化的两层多层感知机,并获得了不错的损失值。然而,我们一直依赖 PyTorch 的自动微分来计算梯度。本节课程的目标是移除 loss.backward() 的调用,手动在张量级别实现反向传播。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_33.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_34.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_36.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_37.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_39.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_40.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_41.png

理解反向传播的内部原理非常重要,因为它是一个“有漏洞的抽象”。仅仅堆叠各种函数模块并指望反向传播自动工作是不够的。如果不理解其内部机制,可能会遇到梯度消失、爆炸或死神经元等问题,并且难以进行有效的调试。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_43.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_44.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_45.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_47.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_49.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_50.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_52.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_53.png

历史背景与动机

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_55.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_56.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_57.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_59.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_61.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_63.png

大约十年前,手动编写反向传播是深度学习的标准做法。如今,虽然自动微分已成为主流,但手动实现反向传播仍然是一项极具价值的练习。它能让你对神经网络的工作原理有更直观、更深刻的理解,使你成为更强大的神经网络调试者和构建者。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_65.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_67.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_68.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_70.png

练习设置

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_72.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_73.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_75.png

我们将通过一系列练习来逐步实现完整的反向传播:

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_77.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_79.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_81.png

  1. 逐原子操作反向传播:将前向传播分解为最小的原子操作,并手动为每个中间变量计算梯度。

  2. 交叉熵损失的高效反向传播:使用数学推导,直接计算损失函数对 logits 的梯度,避免繁琐的逐层回传。

  3. 批归一化的高效反向传播:类似地,为批归一化层推导并实现一个高效、紧凑的梯度计算公式。

  4. 整合训练循环:将手动计算的反向传播代码整合到完整的训练循环中,验证其效果。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_83.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_84.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_86.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_88.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_90.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_92.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_93.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_94.png

接下来,我们将开始第一个练习,手动反向传播整个计算图。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_95.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_97.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_98.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_99.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_101.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_103.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_104.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_105.png


https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_107.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_109.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_111.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_112.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_114.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_116.png

练习一:逐原子反向传播 🔬

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_118.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_119.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_121.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_123.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_125.png

上一节我们概述了目标,本节我们将从计算图的末端开始,一步步手动计算每个中间变量的梯度。我们首先关注损失函数对 logprobs 的梯度。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_127.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_129.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_131.png

计算 Dlogprobs

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_133.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_135.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_137.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_139.png

logprobs 是一个形状为 (32, 27) 的张量。损失函数是正确字符对数概率的负平均值。因此,只有被 y_b 索引到的那些 logprobs 元素对损失有贡献。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_141.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_142.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_143.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_145.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_147.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_148.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_150.png

公式推导

对于被选中的元素,其梯度为 -1/n(其中 n 是批量大小,此处为 32)。对于其他未被选中的元素,梯度为 0。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_152.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_153.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_155.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_157.png

代码实现

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), yb] = -1.0 / n

通过比较,我们验证了手动计算的 dlogprobs 与 PyTorch 自动计算的 logprobs.grad 完全一致。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_159.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_161.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_162.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_164.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_166.png

计算 Dprobs

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_168.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_170.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_172.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_174.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_176.png

logprobsprobs 经过逐元素对数运算得到。根据微积分,d(ln(x))/dx = 1/x

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_178.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_179.png

公式推导

dprobs = (1 / probs) * dlogprobs

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_181.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_183.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_185.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_187.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_189.png

代码实现

dprobs = (1 / probs) * dlogprobs

验证通过。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_191.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_193.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_195.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_196.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_197.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_199.png

计算 Dcounts_sum_inv 和 Dcounts

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_201.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_203.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_205.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_207.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_209.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_211.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_213.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_215.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_217.png

probscounts 除以 counts_sum(并经过广播)得到。这里涉及乘法和广播操作的反向传播。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_219.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_221.png

公式推导

  • 对于乘法 probs = counts * counts_sum_invcounts_sum_inv 的梯度需要从 dprobs 获取,并由于广播(重用)而在第0维(样本维)求和。

  • counts 的梯度来自两条路径:一条直接来自 dprobs,另一条通过 counts_sumcounts_sum_inv 间接传来。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_223.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_224.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_226.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_228.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_230.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_232.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_234.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_236.png

代码实现

# 梯度流向 counts_sum_inv (广播前的变量)
dcounts_sum_inv = (counts * dprobs).sum(0, keepdim=True)
# 梯度流向 counts (第一条路径)
dcounts = counts_sum_inv * dprobs
# 后续还会通过 counts_sum 分支补充梯度

验证了 dcounts_sum_inv 的正确性。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_237.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_238.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_240.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_242.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_244.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_246.png

计算 Dcounts_sum 和补充 Dcounts

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_248.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_250.png

counts_sum_invcounts_sum-1 次幂。counts_sumcounts 在第1维(特征维)的和。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_252.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_254.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_256.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_257.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_259.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_261.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_263.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_265.png

公式推导

  • d(x^-1)/dx = -x^-2,因此 dcounts_sum = -counts_sum**-2 * dcounts_sum_inv

  • 求和操作的反向传播是广播:counts_sum 的梯度被复制到 counts 的每个贡献元素上。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_267.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_269.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_271.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_273.png

代码实现

# 梯度流向 counts_sum
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
# 梯度通过求和操作流向 counts (第二条路径)
dcounts += torch.ones_like(counts) * dcounts_sum

验证通过,并且补充第二条路径后,dcounts 也正确了。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_275.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_276.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_278.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_280.png

计算 Dnormlogits

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_282.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_284.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_285.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_287.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_289.png

countsnormlogits 的逐元素指数。d(e^x)/dx = e^x

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_291.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_293.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_294.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_296.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_298.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_299.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_301.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_303.png

公式推导

dnormlogits = counts * dcounts

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_305.png

代码实现

dnormlogits = counts * dcounts

验证通过。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_307.png

计算 Dlogits 和 Dlogitmax

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_309.png

normlogits = logits - logitmax,其中 logitmaxlogits 每行的最大值,会进行广播。

公式推导

  • dlogits 直接接收来自 dnormlogits 的梯度。

  • dlogitmax 接收来自 dnormlogits 的梯度,并由于广播而在第1维求和(带负号)。

  • logitmax 本身又由 logits 计算而来,因此 logits 还会从 logitmax 接收第二条路径的梯度。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_311.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_313.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_315.png

代码实现

# 第一条路径:来自减法的梯度
dlogits = dnormlogits.clone()
dlogitmax = (-dnormlogits).sum(1, keepdim=True)
# 第二条路径:来自最大值操作的梯度
# 使用 one-hot 编码将梯度散射回最大值出现的位置
dlogits += F.one_hot(logitmax_idx, logits.shape[1]) * dlogitmax

验证通过。值得注意的是,logitmax 的梯度理论上应为0(因为平移不影响 softmax),实际数值也极小(~1e-9),符合预期。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_317.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_319.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_320.png

计算 Dh, Dw2, Db2 (第一个线性层)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_322.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_324.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_325.png

logits = h @ w2 + b2。这是矩阵乘法和加法的反向传播。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_327.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_329.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_331.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_333.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_334.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_335.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_336.png

公式推导(通过维度匹配法)

  • dh 的形状必须与 h (32, 64) 相同。dlogits(32, 27)w2(64, 27)。唯一能使维度匹配的运算是:dh = dlogits @ w2.T

  • dw2 的形状是 (64, 27)h(32, 64)。唯一匹配的运算是:dw2 = h.T @ dlogits

  • db2 的形状是 (27,)。加法操作的梯度是求和,需要消除批量维:db2 = dlogits.sum(0)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_338.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_340.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_342.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_344.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_346.png

代码实现

dh = dlogits @ w2.T
dw2 = h.T @ dlogits
db2 = dlogits.sum(0)

验证通过。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_348.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_350.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_352.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_354.png

计算 Dhpreact (Tanh 激活层)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_356.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_357.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_358.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_360.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_362.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_364.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_365.png

h = tanh(hpreact)d(tanh(x))/dx = 1 - tanh(x)^2

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_367.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_369.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_371.png

公式推导

dhpreact = (1 - h**2) * dh

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_373.png

代码实现

dhpreact = (1 - h**2) * dh

验证通过。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_375.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_377.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_379.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_381.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_383.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_385.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_387.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_389.png

计算 Dbngain, Dbias, Dbnraw (批归一化缩放偏移)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_391.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_392.png

hpreact = bnraw * bngain + bnbias,涉及逐元素乘法和广播。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_394.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_396.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_398.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_400.png

公式推导

  • dbngain:来自乘法,梯度为 bnraw * dhpreact,并由于 bngain 的广播而在第0维求和。

  • dbnraw:来自乘法,梯度为 bngain * dhpreact(广播自动处理)。

  • dbnbias:来自加法,梯度为 dhpreact,并由于广播而在第0维求和。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_402.png

代码实现

dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbnbias = dhpreact.sum(0, keepdim=True)

验证通过。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_404.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_406.png

计算 Dbnvar_inv 和 Dbnvar

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_408.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_409.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_411.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_412.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_414.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_416.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_417.png

bnraw = bndiff * bnvar_invbnvar_inv = (bnvar + eps)**-0.5

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_419.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_421.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_423.png

公式推导

  • dbndiff (第一条路径):dbndiff = bnvar_inv * dbnraw

  • dbnvar_invdbnvar_inv = bndiff * dbnraw,并求和(因为 bnvar_inv 是标量广播而来?这里需要根据形状仔细处理,bnvar_inv(1,64),广播到 (32,64),因此梯度需要沿第0维求和)。

  • dbnvard(x^-0.5)/dx = -0.5 * x^-1.5,所以 dbnvar = -0.5 * (bnvar+eps)**-1.5 * dbnvar_inv

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_425.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_427.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_429.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_431.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_433.png

代码实现

# 第一条路径流向 bndiff
dbndiff = bnvar_inv * dbnraw
# 流向 bnvar_inv
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# 流向 bnvar
dbnvar = (-0.5 * (bnvar+eps)**-1.5) * dbnvar_inv

验证了 dbnvar_invdbnvar 的正确性。dbndiff 此时还不完整,因为 bndiff 还有另一条输入路径。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_435.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_437.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_439.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_441.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_442.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_444.png

计算 Dbndiff2 和补充 Dbndiff

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_446.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_448.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_450.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_452.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_454.png

bnvar = (bndiff2).sum(0) / (n-1)bndiff2 = bndiff**2

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_456.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_458.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_460.png

公式推导

  • dbndiff2:求和的反向传播是广播。dbndiff2 = (1/(n-1)) * torch.ones_like(bndiff2) * dbnvar(利用广播)。

  • dbndiff (第二条路径):d(x^2)/dx = 2x,所以 dbndiff += 2 * bndiff * dbndiff2

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_462.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_464.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_466.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_468.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_470.png

代码实现

# 流向 bndiff2
dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar
# 第二条路径流向 bndiff
dbndiff += 2 * bndiff * dbndiff2

验证通过,现在 dbndiff 也正确了。

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_472.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_474.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_476.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_478.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_480.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_482.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_484.png

计算 Dhprebn 和 Dbndiff_mean (均值减法)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_486.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_487.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_489.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_491.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_493.png

bndiff = hprebn - bndiff_meanbndiff_mean = hprebn.mean(0, keepdim=True)

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_495.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_497.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_499.png

公式推导

  • dhprebn (第一条路径):直接来自 dbndiff,梯度为 1 * dbndiff

  • dbndiff_mean:来自 dbndiff,梯度为 -dbndiff.sum(0, keepdim=True)

  • hprebn (第二条路径):来自 bndiff_mean。均值的反向传播是广播。dhprebn += (1/n) * torch.ones_like(hprebn) * dbndiff_mean

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_501.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_503.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_505.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_507.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_509.png

https://github.com/OpenDocCN/dsai-notes-pt1-zh/raw/master/docs/andrej/img/a3d21168dc6b383673c85535a01b0c8f_511.png

代码实现

# 第一条路径
dhprebn = dbndiff.clone()
dbndiff_mean = (-dbndiff).sum(0, keepdim=True)
# 第二条路径
dhprebn += (1.0/n) * torch.ones_like(hprebn) * dbndiff_mean
Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐