作者:Basil Hosmer

使用 3D 可视化矩阵乘法表达式、具有真实权重的注意力头等。

矩阵乘法(matmuls)是当今机器学习模型的基本组成部分。本文介绍了一个可视化工具 mm,用于可视化矩阵乘法和矩阵乘法的组合。

矩阵乘法本质上是一个三维运算。由于 mm 使用了所有三个空间维度,因此它可以比通常的“纸上画方块”方法更清晰、更直观地传达意义,尤其是对于视觉/空间思维者而言(尽管不限于此)。

我们还有空间以几何一致的方式组合矩阵乘法——因此我们可以使用与简单表达式相同的规则可视化像注意力头和 MLP 层这样的大型复合结构。更高级的功能,如动画展示不同的矩阵乘法算法、并行分区以及加载外部数据来探索实际模型的行为,都自然地建立在这个基础上。

mm 完全交互式,可以在浏览器中运行,并将其完整状态保存在 URL 中,因此链接即为可共享的会话(本文中的截图和视频都有链接,可以直接在工具中打开相应的可视化)。本参考指南描述了所有可用功能。

我们将首先介绍可视化方法,通过可视化一些简单的矩阵乘法和表达式来建立直观理解,然后深入探讨一些更扩展的例子。

  1. 亮点 - 为什么这种可视化方法更好?
  2. 热身 - 动画 - 观看典型的矩阵乘法分解过程
  3. 热身 - 表达式 - 快速浏览一些基本的表达式构建块
  4. 注意力头内部 - 通过 NanoGPT 深入了解 GPT2 中一些注意力头的结构、值和计算行为
  5. 并行化注意力 - 通过最近的 Blockwise Parallel Transformer 论文中的例子,可视化注意力头的并行化
  6. 注意力层中的尺寸 - 当我们将整个注意力层可视化为一个单一结构时,MHA 和 FFA 两半看起来是什么样的?在自回归解码期间,图像会如何变化?
  7. LoRA - 对这种注意力头架构改进的可视化解释
  8. 总结 - 下一步和征集反馈

1 亮点

mm 的可视化方法基于这样一个前提:矩阵乘法本质上是一个三维运算

换句话说,这个(二维图)

matrix multiplication is fundamentally a three-dimensional operation

是在努力表现这个(三维图)(在 mm 中打开)

wrap the matmul around a cube

当我们以这种方式将矩阵乘法包裹在一个立方体周围时,参数形状、结果形状和共享维度之间的正确关系就自然到位了。

现在计算有了几何意义:结果矩阵中的每个位置 i, j 都锚定了一个沿立方体内部深度维度 k 延伸的向量,这是从 L 中行 i 延伸的水平面和从 R 中列 j 延伸的垂直面相交的地方。沿着这个向量,左、右参数中成对的 (i, k)(k, j) 元素相遇并相乘,然后将得到的乘积沿 k 求和,并将结果存入结果矩阵的 i, j 位置。

( 잠깐 미리 살펴보자면, 这里是动画。)

这是矩阵乘法的直观含义:

  1. 将两个正交矩阵投影到立方体内部
  2. 相乘每个交点处的数值对,形成一个乘积网格
  3. 沿着第三个正交维度求和,得到结果矩阵。

为了定向,工具在立方体内部显示一个指向结果矩阵的箭头,其中蓝色箭头来自左参数,色箭头来自参数。工具还显示了白色参考线,指示每个矩阵的行轴,尽管在此截图中它们较淡。

布局约束很简单:

  • 左参数和结果必须沿着它们共享的高度 (i) 维度相邻
  • 右参数和结果必须沿着它们共享的宽度 (j) 维度相邻
  • 左参数和右参数必须沿着它们共享的(左宽度/右高度)维度相邻,这个维度变成了矩阵乘法的深度 (k) 维度

这种几何结构为可视化所有标准的矩阵乘法分解提供了一个坚实的基础,也为探索非平凡复杂的矩阵乘法组合提供了一个直观的基础,我们将在下文看到。

2 热身 - 动画

在深入探讨更复杂的例子之前,我们将进行一些直观的练习,以感受这种可视化风格下的外观和感觉。

2a 点积

首先,典型的算法——通过计算相应的左行和右列的点积来计算每个结果元素。我们在动画中看到的是乘数值向量在立方体内部的扫过,每个向量都在相应位置产生求和结果。

这里,L 有填充着 1(蓝色)或 -1(红色)的行块;R 有类似填充的列块。k 在这里是 24,因此结果矩阵 (L @ R) 的蓝色值为 24,红色值为 -24(在 mm 中打开 - 长按或按住 Control 单击查看值)

2b 矩阵-向量乘积

分解为矩阵-向量乘积的矩阵乘法看起来像是一个垂直平面(左参数与右参数的每一列的乘积)在立方体内部水平扫过,将列绘制到结果上(在 mm 中打开)

观察分解的中间值会非常有趣,即使在简单的例子中也是如此。

例如,请注意当我们使用随机初始化的参数时,中间矩阵-向量乘积中显著的垂直模式——这反映了每个中间产物都是左参数经过列缩放的副本这一事实(在 mm 中打开)

2c 向量-矩阵乘积

分解为向量-矩阵乘积的矩阵乘法看起来像是一个水平平面在立方体内部向下扫过,将行绘制到结果上(在 mm 中打开)

切换到随机初始化的参数后,我们看到了与矩阵-向量乘积相似的模式——只是这次模式是水平的,这对应于每个中间向量-矩阵乘积都是右参数经过行缩放的副本这一事实。

在思考矩阵乘法如何表达其参数的秩和结构时,将这两种模式同时发生在计算中进行设想非常有用(在 mm 中打开)

这里还有一个使用向量-矩阵乘积的直观示例,展示了单位矩阵如何像一面设置在与其相对参数和结果都呈 45 度角的镜子一样工作(在 mm 中打开)

2d 外积求和

第三种平面分解是沿 k 轴进行,通过向量外积的逐点求和来计算矩阵乘积的结果。在这里,我们看到外积平面从立方体的“后方到前方”扫过,累积到结果中(在 mm 中打开)

将随机初始化的矩阵用于这种分解时,我们不仅可以看到数值,还可以看到在结果中累积,因为每个秩为 1 的外积都被添加到结果中。

除此之外,这还建立了对为什么“低秩分解”(即通过构造一个深度维度较小的矩阵乘法来近似一个矩阵)在被近似矩阵本身为低秩时效果最佳的直观理解。后面的 LoRA 部分(在 mm 中打开)提供了这种见解的一个例子。

3 热身 - 表达式

我们如何将这种可视化方法扩展到矩阵乘法的组合?到目前为止,我们的例子都只可视化了矩阵 LR 的单个矩阵乘法 L @ R——那么当 L 和/或 R 本身就是矩阵乘法,并依此类推地传递下去时,会是怎样呢?

事实证明,我们可以很好地将这种方法扩展到复合表达式。关键规则很简单:子表达式(子级)的矩阵乘法是另一个立方体,遵循与父级相同的布局约束,并且子级的结果面与父级对应的参数面同时共享,就像共享的共价电子一样。

在这些约束内,我们可以随意排列子矩阵乘法的面。在这里,我们使用了工具的默认方案,它生成交替的凸面和凹面立方体——这种布局在实践中效果很好,可以最大限度地利用空间并最小化遮挡。(然而,布局是完全可定制的——请参阅参考了解详细信息。)

在本节中,我们将可视化机器学习模型中的一些关键构建块,以熟悉视觉语言,并看看即使是简单的例子也能给我们带来什么样的直观理解。

3a 左结合表达式

我们将查看两个形式为 (A @ B) @ C 的表达式,每个表达式都有其独特的形状和特征。(注意:mm 遵循矩阵乘法是左结合的约定,并将其简单地写为 A @ B @ C。)

首先,我们将给 A @ B @ C 赋予典型的 FFN 形状,其中“隐藏维度”比“输入”或“输出”维度更宽。(具体来说,在这个例子中,这意味着 B 的宽度大于 AC 的宽度。)

与单个矩阵乘法示例中一样,浮动箭头指向结果矩阵,蓝色箭头来自左参数,红色箭头来自右参数(在 mm 中打开)

As in the single matmul examples, the floating arrows point towards the result matrix, blue vane coming from the left argument and red vane from right argument

接下来,我们将可视化 A @ B @ C,其中 B 的宽度比 AC 的宽度,赋予其瓶颈或“自编码器”形状(在 mm 中打开)

visualize A @ B @ C with the width of B narrower than that of A or C

这种凸面和凹面块交替的模式可以扩展到任意长度的链:例如这个多层瓶颈(在 mm 中打开)

pattern of alternating convex and concave blocks extends to chains of arbitrary length

3b 右结合表达式

接下来,我们将可视化一个右结合表达式 A @ (B @ C)

就像左结合表达式水平延伸(可以说从根表达式的左参数“生长”出来)一样,右结合链垂直延伸,从根表达式的右参数“生长”出来。

有时,人们会看到 MLP 被右结合地表述,即输入以列形式位于右侧,权重层从右向左运行。使用上面图片所示的双层 FFN 示例中的矩阵(经过适当转置),看起来是这样的,其中 C 现在扮演输入角色,B 是第一层,A 是第二层 (在 mm 中打开)

an MLP formulated right-associatively

题外话:除了箭头叶片的颜色(左侧蓝色,右侧红色)之外,区分左侧和右侧参数的第二个视觉线索是它们的方向:左侧参数的行与结果的行共面——它们沿同一轴(i)堆叠。这两个线索都告诉我们,例如,B 是上方 (B @ C) 的左侧参数。

3c 二元表达式

对于一个可视化工具来说,要想超越简单的教学示例变得有用,可视化需要随着表达式变得更复杂而保持清晰易读。实际应用中的一个关键结构组件是二元表达式——即左右两侧都有子表达式的矩阵乘法(matmuls)。

这里我们将可视化最简单的此类表达式形式:(A @ B) @ (C @ D) (在 mm 中打开)

binary expressions - matmuls with subexpressions on both the left and right

3d 快速题外话:分区与并行性

本文的范围不包含对这一主题的完整介绍,尽管稍后我们将在注意力头的上下文中看到其应用。但作为热身,两个快速示例应该能让你感受一下这种可视化风格如何通过简单的几何分区使复合表达式的并行化推理变得非常直观。

在第一个示例中,我们将把经典的“数据并行”分区应用于上面的左结合多层瓶颈示例。我们沿 i 轴进行分区,分割初始左侧参数(“批量”)和所有中间结果(“激活”),但不分割后续参数(“权重”)——几何结构使其显而易见,哪些表达式参与者被分割,哪些保持完整 (在 mm 中打开)

the canonical "data parallel" partitioning to the left-associative multilayer bottleneck example

第二个示例(至少对我来说)如果没有清晰的几何结构支持,会很难建立直观理解:它展示了如何通过沿左侧子表达式的 j 轴、沿右侧子表达式的 i 轴以及沿父表达式的 k 轴分区来并行化二元表达式 (在 mm 中打开)

a binary expression can be parallelized by partitioning the left subexpression along its j axis, the right subexpression along its i axis, and the parent expression along its k axis

4 注意力头内部

让我们来看一个 GPT2 注意力头——具体来说,是来自 NanoGPT 的“gpt2”(小型)配置(层数=12,头数=12,嵌入维度=768)的第 5 层、第 4 个头,使用通过 HuggingFace 获取的 OpenAI 权重。输入激活值取自一个 OpenWebText 训练样本(包含 256 个 token)的前向传播过程。

这个特定的注意力头没有什么特别之处;我选择它主要是因为它计算了一个相当常见的注意力模式,并且位于模型的中间,在那里,激活值已经变得结构化并显示出一些有趣的纹理。(题外话:在后续文章中,我将介绍一个注意力头探索器,它允许你可视化该模型的所有层和头,并附带一些使用说明。)

在 mm 中打开(可能需要几秒钟获取模型权重)

There's nothing particularly unusual about this particular head

4a 结构

整个注意力头被可视化为一个单一的复合表达式,从输入开始,到投影输出结束。(注意:为了保持内容的独立性,我们按照 Megatron-LM 中描述的方式进行每个头的输出投影。)

计算包含六个矩阵乘法

Q = input @ wQ        // 1
K_t = wK_t @ input_t  // 2
V = input @ wV        // 3
attn = sdpa(Q @ K_t)  // 4
head_out = attn @ V   // 5
out = head_out @ wO   // 6

我们所看到内容的简要描述

  • 风车的叶片是矩阵乘法 1、2、3 和 6:前一组是从输入到 Q、K 和 V 的输入投影;后者是从 attn @ V 到嵌入维度的输出投影。
  • 在中心是两次矩阵乘法,首先计算注意力分数(后面的凸形立方体),然后使用它们从值向量(前面的凹形立方体)生成输出 token。因果关系意味着注意力分数形成一个下三角矩阵。

但我鼓励在工具中亲自探索这个示例,而不是仅仅依靠截图或下面的视频来理解能从中获取多少信息,包括关于其结构以及流经计算过程的实际值。

4b 计算与值

这里是注意力头计算的一个动画。具体来说,我们正在观看

sdpa(input @ wQ @ K_t) @ V @ wO

(即上面的矩阵乘法 1、4、5 和 6,其中 K_tV 已预先计算好)作为向量-矩阵乘法的融合链进行计算:序列中的每个项都在一步中从输入经过注意力到达输出。关于这个动画选择的更多内容将在后面关于并行化的部分介绍,但首先让我们看看计算得到的值告诉了我们什么。

在 mm 中打开

这里有很多有趣的事情正在发生。

  • 在我们甚至开始注意力计算之前,QK_t 的低秩程度令人相当惊讶。放大 Q @ K_t 向量-矩阵乘法动画,情况更加生动:QK两者都有相当数量的通道(嵌入位置)在整个序列中看起来或多或少是恒定的,这意味着有用的注意力信号可能只由嵌入维度中的一小部分驱动。理解和利用这一现象是我们作为 SysML ATOM transformer efficiency 项目的一部分正在探索的课题之一。
  • 也许最常见的是注意力矩阵中出现的强大但并非完美的对角线。这是一种常见的模式,出现在该模型的许多注意力头(以及许多 transformer 模型)中。它产生局部化的注意力:输出 token 位置紧邻前方的少量邻域内的值 token 在很大程度上决定了该输出 token 的内容模式。
  • 然而,该邻域的大小及其内部各个 token 的影响并非微不足道,而是有显著变化的——这可以在注意力网格中的非对角线“霜”中看到,以及 attn[i] @ V 向量-矩阵乘积平面在沿序列下降通过注意力矩阵时出现的波动模式中看到。
  • 但请注意,局部邻域并不是唯一引起注意力的东西:注意力网格最左侧的一列,对应于序列的第一个 token,完全充满了非零(但波动)的值,这意味着每个输出 token 都会在某种程度上受到第一个值 token 的影响。
  • 此外,当前 token 邻域与初始 token 之间的注意力分数主导权存在一种不精确但可辨别的振荡。这种振荡的周期各不相同,但大致来说,随着序列的进行,周期先是较短,然后逐渐变长(根据因果关系,这与每一行的候选注意力 token 数量有启发性的关联)。
  • 要了解 attn @ V 是如何形成的,重要的是不要孤立地关注注意力——V 是一个同等重要的参与者。每个输出项都是整个 V 向量的加权平均;在注意力是完美对角线的情况下,attn @ V 仅仅是 V 的一个精确复制。这里我们看到了一些更具纹理的内容:可以看到明显的条带,其中某些 token 在连续的注意力行子序列上得分很高,叠加在一个明显类似于 V 的矩阵上,但由于较宽的对角线而有一些垂直的涂抹效果。(题外话:根据 mm 参考指南,长按或按住 Control 键点击将显示可视化元素的实际数值。)
  • 请记住,由于我们处于中间层(5),因此该注意力头的输入是中间表示,而不是原始的分词文本。因此,输入中看到的模式本身就引人深思——特别是,那些强烈的垂直“线程”是指特定嵌入位置的值在序列的很大一部分范围内(有时几乎是整个序列)都具有统一的高幅值。
  • 有趣的是,输入序列中的第一个向量却很特别,它不仅打破了这些高幅值列的模式,而且在几乎每个位置都带有非典型值(附注:此处未可视化,但此模式在多个样本输入中重复出现)。

注意:关于最后两点,值得重申的是,我们正在可视化基于单个样本输入的计算。在实践中我发现,每个头部在一系列不错的样本中都会一致地(尽管并非完全相同地)表现出其特有的模式(即将推出的注意力头浏览器将提供一系列样本供您尝试),但当查看任何包含激活的可视化时,务必记住完整的输入分布可能会以微妙的方式影响其引发的思考和直觉。

最后,再次推荐您直接探索动画

4c 头部在有趣的方式上有所不同

在我们继续之前,这里还有另一个例子,演示了简单地“摆弄”模型来详细了解其工作原理是多么有用。

这是 GPT2 中的另一个注意力头。它的行为与上面提到的第 5 层第 4 个头(layer 5, head 4)大相径庭——正如所料,因为它位于模型的非常不同的部分。这个头位于第一层:第 0 层第 2 个头(在 mm 中打开,可能需要几秒钟加载模型权重)

This is another attention head from GPT2

值得注意的地方

  • 这个头的注意力分布非常均匀。这样做的效果是,在 attn @ V 中,将 V 的相对未加权平均(或者更确切地说,是 V 的适当因果前缀)传递到每一行,如这个动画所示:当我们沿着注意力得分三角形向下移动时,attn[i] @ V 向量-矩阵乘积与 V 的缩小版、逐步揭示的副本仅有微小波动。
  • attn @ V 具有惊人的垂直均匀性——在嵌入的大片列状区域中,相同的值模式会持续贯穿整个序列。可以将这些视为所有 token 共享的属性。
  • 附注:一方面,考虑到注意力分布非常均匀的效果,人们可能会期望 attn @ V 中具有一定的均匀性。但每一行仅由 V 的因果子序列而非整体构建而成——为什么这没有引起更多的变化,例如随着序列向下移动而产生的渐进变形?通过视觉检查 V 沿着其长度方向并非均匀的,因此答案一定在于其值分布的某种更微妙的属性。
  • 最后,这个头的输出在经过输出投影后更加垂直均匀
  • 强烈的印象是,这个注意力头传递的大部分信息由序列中所有 token 共享的属性构成。其输出投影权重的构成强化了这种直觉。

总的来说,很难否定这样的想法:这个注意力头产生的极其规律、高度结构化的信息,可能可以通过某种……不那么“奢华”的计算方式获得。当然,这不是一个未被探索的领域,但可视化计算的特异性和丰富的信号有助于产生新的想法,并推理由现有想法。

4d 重温要点:免费获得不变量

回顾一下,值得重申的是,我们之所以能够可视化注意力头这样非平凡的复合操作并使其保持直观性,是因为重要的代数属性——例如参数形状如何受限,或者哪些并行化轴与哪些操作相交——无需额外思考:它们直接来源于可视化对象的几何结构,而不是需要记住的额外规则。

例如,在这些注意力头可视化中,可以立即清楚地看到:

  • Qattn @ V 的长度相同,KV 的长度相同,并且这两对的长度彼此独立
  • QK 的宽度相同,Vattn @ V 的宽度相同,并且这两对的宽度彼此独立。

这些属性是结构本身固有的真实属性,是复合结构组成部分所处的区域及其方向的简单结果。

这种“免费属性”的好处在探索规范结构的变体时特别有用——一个显而易见的例子是在自回归逐令牌解码中一行高的注意力矩阵(在 mm 中打开

the one-row-high attention matrix in autoregressive token-at-a-time decoding

5 注意力并行化

在上面的第 4 层第 5 个注意力头的动画中,我们可视化了注意力头中 6 个矩阵乘法中的 4 个

作为一个向量-矩阵乘积的融合链,这证实了几何直觉,即从输入到输出的整个左结合链沿共享的 i 轴呈层状,并且可以并行化。

5a 示例:沿 i 轴划分

实际上为了并行化计算,我们会沿着 i 轴将输入划分为块。我们可以在工具中通过指定将某个轴划分为特定数量的块来可视化这种划分——在这些示例中我们将使用 8 个块,但这个数字并没有什么特别之处。

除此之外,这种可视化清楚地表明,每个并行计算都需要完整的 wQ(用于输入投影)、K_tV(用于注意力)以及 wO(用于输出投影),因为它们沿着未划分的维度与被划分的矩阵相邻(在 mm 中打开

wQ (for in-projection), K_t and V (for attention) and wO (for out-projection) are needed in their entirety by each parallel computation

5b 示例:双重划分

作为沿多个轴划分的一个示例,我们可以可视化一些最近在这个领域进行创新的工作(Block Parallel Transformer,基于 Flash Attention 及其前身等工作)。

首先,BPT 如上所述沿着 i 轴进行划分——并且实际上将这种序列的横向划分扩展到注意力层的后半部分(FFN)的整个过程。(我们将在后面的部分可视化这一点。)

为了完全解决上下文长度问题,接着在 MHA 中添加了第二种划分——即注意力计算本身的划分(也就是说,沿 Q @ K_tj 轴进行划分)。这两种划分共同将注意力划分为块的网格(在 mm 中打开

The two partitions together divide attention into a grid of blocks

这种可视化清楚地表明

  • 这种双重划分作为解决上下文长度问题的有效性,因为我们现在已经明显地划分了注意力计算中序列长度的每次出现
  • 第二种划分的“范围”:从几何学上可以清楚地看出,KV 的输入投影计算可以与核心的双重矩阵乘法一起进行划分

注意一个微妙之处:这里的视觉含义是,我们也可以将后续的矩阵乘法 attn @ V 沿着 k 轴并行化,并以 split-k 方式 对部分结果求和,从而并行化整个双重矩阵乘法。但是 sdpa() 中的按行 softmax 增加了要求,即在计算 attn @ V 的相应行之前,每一行必须将其所有段进行归一化,这在注意力计算和最终矩阵乘法之间增加了一个额外的按行步骤。

6 注意力层中的大小

注意力层的第一部分(MHA)以其二次复杂度而闻名,计算量巨大,但第二部分(FFN)由于其隐藏维度很宽,通常是模型嵌入维度的 4 倍,计算量也很大。可视化整个注意力层的“生物量”(biomass)有助于直观了解层的前后两部分如何相互比较。

6a 可视化整个层

下面是一个完整的注意力层,前半部分(MHA)在背景,后半部分(FFN)在前景。像往常一样,箭头指向计算的方向。

注意事项

  • 这种可视化并没有描绘单个注意力头,而是展示了围绕着一个中心双重矩阵乘法的未切片 Q/K/V 权重和投影。当然,这并不是对完整 MHA 操作的忠实可视化——但这里的目标是更清楚地了解层前后两部分的相对矩阵大小(sizes),而不是每一部分执行的相对计算量。(此外,使用了随机值而不是真实权重。)
  • 这里使用的维度为了让浏览器(相对地)保持正常而缩小了,但比例得以保留(来自 NanoGPT 的小型配置):模型嵌入维度 = 192(原为 768),FFN 嵌入维度 = 768(原为 3072),序列长度 = 256(原为 1024),尽管序列长度对模型来说并非基础。(从视觉上看,序列长度的变化会表现为输入“叶片”(blades)宽度的变化,进而影响注意力“中心”(hub)的大小和下游垂直平面的高度。)

在 mm 中打开:

a full attention layer with the first half (MHA) in the background and the second (FFN) in the foreground

6b 可视化 BPT 划分的层

简要回顾 Blockwise Parallel Transformer,这里我们在整个注意力层的背景下可视化 BPT 的并行化方案(如上所述省略了单个注意力头)。特别要注意,沿着 i 轴(序列块)的划分如何贯穿 MHA 和 FFN 两部分(在 mm 中打开

visualize BPT's parallelization scheme in the context of an entire attention layer

6c 划分 FFN

这种可视化表明可以进行额外的划分,与上述划分正交——在注意力层的 FFN 部分,将双重矩阵乘法 (attn_out @ FFN_1) @ FFN_2 进行拆分,首先沿 j 轴对 attn_out @ FFN_1 进行划分,然后在与 FFN_2 的后续矩阵乘法中沿 k 轴进行划分。这种划分切分了 FFN 权重的两层,在计算中降低了每个参与者的容量需求,但代价是最终需要对部分结果进行求和。

下面是将这种划分应用于一个未进行其他划分的注意力层的样子(在 mm 中打开

what this partition looks like applied to an otherwise unpartitioned attention layer

下面是将其应用于一个按照 BPT 方式划分的层的样子(在 mm 中打开

applied to a layer partitioned a la BPT

6d 可视化逐令牌解码

在自回归逐令牌解码过程中,查询向量由单个令牌组成。在这种情况下,对注意力层是什么样子有一个心理画面是很有启发性的——一个单一的嵌入行穿过巨大的平铺权重平面。

除了强调权重相对于激活值的巨大体量之外,这种视图还让人联想到 K_tV 的功能类似于 6 层 MLP 中动态生成的层,尽管 MHA 本身的复用/解复用计算(这里为了简化而忽略,如上所述)使得这种对应关系不完全精确(在 mm 中打开

the mux/demux computations of MHA itself

7 LoRA

最近的 LoRA 论文(LoRA: Large Language Models 的低秩适应)描述了一种基于微调过程中引入的权重变化是低秩的高效微调技术。根据论文,这“允许我们通过优化神经网络中密集层在适应过程中的变化的秩分解矩阵来间接训练一些密集层 […],同时保持预训练权重冻结。”

7a 基本思想

简而言之,关键之处在于训练权重矩阵的*因子*,而不是训练矩阵本身:用一个 I x K 张量和一个 K x J 张量的矩阵乘法来替换一个 I x J 权重张量,其中 K 保持在一个较小的数值。

如果 K 足够小,尺寸上的优势会非常大,但权衡是降低它会降低乘积所能表达的秩。作为对尺寸节省和结果结构化效应的快速说明,这里展示了一个随机 128 x 4 左侧参数和 4 x 128 右侧参数的矩阵乘法——也称为 128 x 128 矩阵的秩-4分解。注意 L @ R 中垂直和水平的图案。(在 mm 中打开

a matmul of random 128 x 4 left and 4 x 128 right arguments

7b 将 LoRA 应用于注意力头

LoRA 将这种分解方法应用于微调过程的方式是

  • 为每个待微调的权重张量创建一个低秩分解并训练这些因子,同时保持原始权重不变
  • 微调后,将每对低秩因子相乘得到一个与原始权重张量形状相同的矩阵,并将其加到原始预训练权重张量上

下面的可视化展示了一个注意力头,其权重张量 wQwK_twVwO 被低秩分解 wQ_A @ wQ_B 等替换。从视觉上看,因子矩阵显示为沿着风车叶片边缘的低矮围栏。(在 mm 中打开 - 按空格键停止旋转)

8 总结

8a 征集反馈

我发现这种可视化矩阵乘法表达式的方式对于建立直觉和推理非常有帮助,不仅对于矩阵乘法本身,而且对于机器学习模型及其计算的许多方面,从效率到可解释性。

如果你试用后有任何建议或评论,我非常希望听到,可以在这里的评论区留言或在仓库中提出

8b 后续步骤

  • 基于该工具构建了一个GPT2 注意力头探索器,我目前正在使用它来清点和分类在该模型中发现的注意力头特征。(我就是使用这个工具来发现和探索本笔记中的注意力头。)完成后,我计划发布一篇包含这份清单的笔记。
  • 如上所述,在 Python Notebook 中嵌入这些可视化是非常简单的。但会话 URL 可能会变得... 难以管理,因此有一个 Python 端的工具函数来从配置对象构建它们将非常有用,类似于参考指南中使用的简单 JavaScript 辅助函数。
  • 如果你有一个你认为可能会受益于这种可视化但又不确定如何使用该工具来实现的用例,请与我联系!我不一定希望大幅扩展其核心可视化功能(选择合适的工具做合适的事等),但例如,用于以编程方式驱动它的 API 相当基础,在这方面还有很多可以改进的地方。