跳转到主要内容
博客

矩阵内部:可视化矩阵乘法、注意力机制及其他

作者: 2023 年 9 月 25 日2024 年 11 月 14 日暂无评论

使用 3D 可视化矩阵乘法表达式、带有真实权重的注意力头等。

矩阵乘法 (matmuls) 是当今机器学习模型的基本组成部分。本文介绍了一款可视化 matmuls 及其组合的工具 mm

矩阵乘法本质上是一种三维运算。由于 mm 使用所有三个空间维度,它比通常的纸上二维表示法能更清晰、更直观地传达含义,特别是(但不仅限于)对于视觉/空间思维者而言。

我们也有空间以几何一致的方式组合 matmuls,因此我们可以使用与简单表达式相同的规则来可视化大型复合结构,如注意力头和 MLP 层。更高级的功能,如动画不同的 matmul 算法、并行分区和加载外部数据以探索实际模型的行为,都自然地建立在此基础上。

mm 是完全交互式的,在 浏览器中 运行,并将其完整状态保存在 URL 中,因此链接是可共享的会话(本文中的截图和视频都有链接,可在工具中打开相应的可视化)。此 参考指南 描述了所有可用功能。

我们将首先介绍可视化方法,通过可视化一些简单的矩阵乘法和表达式来建立直觉,然后深入探讨一些更复杂的示例。

  1. 主旨 – 为什么这种可视化方式更好?
  2. 热身 – 动画 – 观看典型矩阵乘法分解的实际操作
  3. 热身 – 表达式 – 快速浏览一些基本表达式构建块
  4. 注意力头内部 – 深入了解 GPT2 中几个注意力头的结构、值和计算行为,通过 NanoGPT
  5. 并行化注意力 – 可视化注意力头并行化,并以最近的 分块并行 Transformer 论文为例
  6. 注意力层中的大小 – 当我们将整个注意力层作为一个单一结构可视化时,MHA 和 FFA 两半看起来如何?在自回归解码过程中,图像如何变化?
  7. LoRA – 对注意力头架构这一精细化的视觉解释
  8. 总结 – 后续步骤和反馈征集

1 主旨

mm 的可视化方法基于“矩阵乘法本质上是一种三维运算”的论点。

换句话说,这

matrix multiplication is fundamentally a three-dimensional operation

是试图成为这样的纸张 (在 mm 中打开)

wrap the matmul around a cube

当我们以这种方式将矩阵乘法包裹在一个立方体周围时,参数形状、结果形状和共享维度之间的正确关系就都各就各位了。

现在,计算具有了几何意义:结果矩阵中每个位置i, j都锚定了一个沿立方体内部深度维度k运行的向量,左参数中第i行延伸出的水平平面和右参数中第j列延伸出的垂直平面在此相交。沿着这个向量,来自左右参数的(i, k) (k, j)元素对相遇并相乘,然后沿k求和,结果存入结果的i, j位置。

(暂时跳过,这是一个动画。)

这是矩阵乘法的直观含义

  1. 两个正交矩阵投影到立方体内部
  2. 每个交点的数值对相乘,形成一个乘积网格
  3. 沿着第三个正交维度求和,生成结果矩阵。

为了方便定向,工具在立方体内部显示一个指向结果矩阵的箭头,蓝色箭头来自左参数,红色箭头来自右参数。工具还会显示白色指导线以指示每个矩阵的行轴,尽管在此截图中它们很模糊。

布局约束直截了当

  • 左参数和结果必须沿着它们共享的高度(i)维度相邻
  • 右参数和结果必须沿着它们共享的宽度(j)维度相邻
  • 左参数和右参数必须沿着它们共享的(左宽度/右高度)维度相邻,这成为矩阵乘法的深度(k)维度

这种几何结构为我们可视化所有标准矩阵乘法分解提供了坚实的基础,并为探索非平凡复杂的矩阵乘法组合提供了直观的基础,我们将在下面看到。

2 热身 – 动画

在深入探讨更复杂的示例之前,我们将通过几个直观的构建器来感受这种可视化风格的外观和感觉。

2a 点积

首先,是经典算法——通过计算相应左行和右列的点积来计算每个结果元素。我们在动画中看到的是乘积值向量在立方体内部的扫描,每个向量都在相应位置生成一个求和结果。

在这里,L 的行块填充有 1(蓝色)或 -1(红色);R 的列块也类似填充。这里 k 是 24,所以结果矩阵(L @ R)的蓝色值为 24,红色值为 -24 (在 mm 中打开 – 长按或按住 Control 键点击以检查数值)

2b 矩阵-向量乘积

分解为矩阵-向量乘积的矩阵乘法看起来像一个垂直平面(左参数与右参数每列的乘积),它在立方体内部水平扫描时将列绘制到结果上 (在 mm 中打开)

观察分解的中间值会非常有趣,即使在简单的例子中也是如此。

例如,当我们使用随机初始化的参数时,请注意中间矩阵-向量乘积中明显的垂直模式——这反映了每个中间值是左参数按列缩放的副本这一事实 (在 mm 中打开)

2c 向量-矩阵乘积

分解成向量-矩阵乘积的矩阵乘法看起来像一个水平平面,它在立方体内部下降时将行绘制到结果上 (在 mm 中打开)

切换到随机初始化的参数,我们看到与矩阵-向量乘积相似的模式——只是这次模式是水平的,这与每个中间向量-矩阵乘积是右参数按行缩放的副本这一事实相对应。

当思考矩阵乘法如何表达其参数的秩和结构时,设想这两种模式同时在计算中发生会很有用 (在 mm 中打开)

这是另一个使用向量-矩阵乘积的直观构建器,展示了单位矩阵如何像一面与反参数和结果成 45 度角放置的镜子一样工作 (在 mm 中打开)

2d 求和外积

第三个平面分解是沿着 k 轴,通过向量外积的点态求和来计算矩阵乘法结果。在这里,我们看到外积平面“从后到前”扫描立方体,并累积到结果中 (在 mm 中打开)

使用随机初始化的矩阵进行这种分解,我们不仅可以看到值,还可以看到结果中累积的,因为每个秩为 1 的外积都添加到其中。

除此之外,这也有助于理解为什么“低秩分解”——即通过构建深度维度较小的矩阵乘法来近似矩阵——在被近似的矩阵是低秩时效果最好。LoRA 在后面的章节中 (在 mm 中打开)

3 热身 – 表达式

我们如何将这种可视化方法扩展到矩阵乘法的组合?到目前为止,我们的例子都可视化了单个矩阵乘法L @ R,其中L和/或R本身就是矩阵乘法,并依此类推?

事实证明,我们可以很好地将这种方法扩展到复合表达式。关键规则很简单:子表达式(子)矩阵乘法是另一个立方体,服从与父表达式相同的布局约束,并且子表达式的结果面同时是父表达式的相应参数面,就像一个共价共享的电子。

在这些约束下,我们可以自由地安排子矩阵乘法的各个面。这里我们使用工具的默认方案,它生成交替的凸形和凹形立方体——这种布局在实践中能很好地最大化空间利用并最小化遮挡。(然而,布局是完全可定制的——详见参考资料。)

在本节中,我们将可视化机器学习模型中的一些关键构建块,以熟悉视觉表达方式,并了解即使是简单的示例也能给我们带来哪些直觉。

3a 左结合表达式

我们将研究两个形式为 (A @ B) @ C 的表达式,每个都有其独特的形状和特征。(注意:mm 遵循矩阵乘法是左结合的约定,并将其简单地写为 A @ B @ C。)

首先,我们将 A @ B @ C 赋予其特征性的 FFN 形状,其中“隐藏维度”比“输入”或“输出”维度更宽。(具体来说,在此示例中,这意味着 B 的宽度大于 AC 的宽度。)

与单个矩阵乘法示例一样,浮动箭头指向结果矩阵,蓝色箭头来自左参数,红色箭头来自右参数 (在 mm 中打开)

As in the single matmul examples, the floating arrows point towards the result matrix, blue vane coming from the left argument and red vane from right argument

接下来,我们将可视化 A @ B @ C,其中 B 的宽度窄于 AC,使其具有瓶颈或“自编码器”形状 (在 mm 中打开)

visualize A @ B @ C with the width of B narrower than that of A or C

这种凸块和凹块交替的模式可以扩展到任意长度的链:例如,这个多层瓶颈 (在 mm 中打开)

pattern of alternating convex and concave blocks extends to chains of arbitrary length

3b 右结合表达式

接下来我们将可视化一个右结合表达式 A @ (B @ C)

就像左结合表达式水平延伸一样——可以说从根表达式的左参数萌芽——右结合链垂直延伸,从根表达式的右参数萌芽。

有时人们会将 MLP(多层感知器)表述为右结合的,即输入列向量在右侧,权重层从右到左排列。使用上面 2 层 FFN 示例中的矩阵——适当地转置后——看起来是这样的,其中 C 现在扮演输入的角色,B 是第一层,A 是第二层 (在 mm 中打开)

an MLP formulated right-associatively

旁注:除了箭头扇叶的颜色(左侧蓝色,右侧红色)外,区分左右参数的第二个视觉线索是它们的方向:左参数的行与结果的行共面——它们沿着相同的轴(i)堆叠。这两个线索都告诉我们,例如,B 是上面 (B @ C) 的左参数。

3c 二元表达式

为了使可视化工具超越简单的教学示例,当表达式变得更复杂时,可视化需要保持可读性。实际用例中的一个关键结构组件是二元表达式——左右两侧都有子表达式的矩阵乘法。

这里我们将可视化最简单的此类表达式形状:(A @ B) @ (C @ D) (在 mm 中打开)

binary expressions - matmuls with subexpressions on both the left and right

3d 快速旁注:分区和并行性

关于这个主题的完整介绍超出了本文的范围,尽管我们将在后面关注力头的背景下看到它的实际应用。但是作为热身,两个快速示例应该能让您了解这种可视化风格如何通过简单的分区几何使并行化复合表达式的推理变得非常直观。

在第一个示例中,我们将对上面的左结合多层瓶颈示例应用典型的“数据并行”分区。我们沿着 i 轴分区,分割初始左参数(“批次”)和所有中间结果(“激活”),但不分割任何后续参数(“权重”)——几何结构清晰地表明了表达式中哪些参与者被分割,哪些保持完整 (在 mm 中打开)

the canonical "data parallel" partitioning to the left-associative multilayer bottleneck example

第二个示例(至少对我而言)如果没有清晰的几何结构支持,将很难直观地理解:它展示了如何通过沿左子表达式的 j 轴、右子表达式的 i 轴以及父表达式的 k 轴进行分区来并行化二元表达式 (在 mm 中打开)

a binary expression can be parallelized by partitioning the left subexpression along its j axis, the right subexpression along its i axis, and the parent expression along its k axis

4 注意力头内部

让我们看看 GPT2 的一个注意力头——具体来说,是 NanoGPT 中“gpt2”(小型)配置(层数=12,头数=12,嵌入=768)的第 5 层、第 4 个注意力头,使用通过 HuggingFace 获取的 OpenAI 权重。输入激活来自 OpenWebText 训练样本(256 个 token)的前向传播。

这个特定的注意力头并没有什么特别之处;我选择它主要是因为它计算了一种相当常见的注意力模式,并且位于模型的中间,那里的激活已经变得结构化并显示出一些有趣的纹理。(旁注:在随后的文章中,我将介绍一个注意力头浏览器,它允许您可视化该模型的所有层和注意力头,并附带一些旅行笔记。)

在 mm 中打开 (可能需要几秒钟来获取模型权重)

There's nothing particularly unusual about this particular head

4a 结构

整个注意力头被可视化为一个复合表达式,从输入开始,到投影输出结束。(注意:为了保持内容的自包含性,我们按照 Megatron-LM 中描述的方式进行每个头的输出投影。)

计算包含六个矩阵乘法

Q = input @ wQ        // 1
K_t = wK_t @ input_t  // 2
V = input @ wV        // 3
attn = sdpa(Q @ K_t)  // 4
head_out = attn @ V   // 5
out = head_out @ wO   // 6

我们正在查看的简要描述

  • 风车的叶片是矩阵乘法 1、2、3 和 6:前者是将输入投影到 Q、K 和 V;后者是将 attn @ V 投影回嵌入维度。
  • 中心是双重矩阵乘法,首先计算注意力分数(后面的凸立方体),然后利用它们从值向量生成输出 token(前面的凹立方体)。因果关系意味着注意力分数形成一个下三角形。

但我鼓励您在工具中探索这个示例,而不是依赖截图或下面的视频来传达它可以从中获取的关于其结构和流经计算的实际值的信号量。

4b 计算与值

这是一个注意力头计算的动画。具体来说,我们正在观察

sdpa(input @ wQ @ K_t) @ V @ wO

(即,上面的矩阵乘法 1、4、5 和 6,其中 K_tV 预先计算)被计算为向量-矩阵乘积的融合链:序列中的每个项都一步从输入到注意力再到输出。关于此动画选择的更多信息将在后面的并行化部分中介绍,但首先让我们看看正在计算的值告诉我们什么。

在 mm 中打开

这里有很多有趣的事情发生。

  • 甚至在进行注意力计算之前,QK_t 的低秩程度就相当惊人。放大 Q @ K_t 向量-矩阵乘积动画,情况更加生动:QK 中有相当数量的通道(嵌入位置)在整个序列中看起来或多或少保持不变,这意味着有用的注意力信号可能只由嵌入的一小部分驱动。理解和利用这种现象是我们 SysML ATOM Transformer 效率项目的一部分。
  • 也许最熟悉的是注意力矩阵中出现的强烈但不完美的对角线。这是一种常见模式,出现在该模型(以及许多 Transformer 模型)的许多注意力头中。它产生局部化注意力:紧邻输出 token 位置的小范围内的值 token 在很大程度上决定了该输出 token 的内容模式。
  • 然而,这个邻域的大小以及其中单个 token 的影响会有非平凡的变化——这既可以在注意力网格的非对角线“霜冻”中看到,也可以在 注意力分数主导地位的不精确但可辨别的振荡,介于当前 token 邻域和初始 token 之间。振荡周期不同,但总体而言,随着序列的向下移动,振荡周期从短到长(与每个行的候选注意力 token 数量具有启发性相关性,考虑因果关系)。
  • 但是请注意,局部邻域并不是唯一吸引注意力的因素:注意力网格的最左列,对应于序列的第一个 token,完全填充了非零(但波动)值,这意味着每个输出 token 都将在某种程度上受到第一个值 token 的影响。
  • 此外,还有不精确但可辨别的振荡,在当前 token 邻域和初始 token 之间交替出现。振荡周期各不相同,但总体来说,随着序列的向下移动,周期开始时较短,然后逐渐变长(与每行的候选注意力 token 数量具有启发性相关性,考虑到因果关系)。
  • 要理解 (attn @ V) 是如何形成的,重要的是不要孤立地关注注意力——V 是同等重要的参与者。每个输出项都是整个 V 向量的加权平均值:当注意力是完美对角线时,attn @ V 简单地是 V 的精确副本。这里我们看到更具纹理的景象:明显的条带,其中特定 token 在连续的注意力行子序列中得分较高,叠加在一个与 V 明显相似但由于粗对角线而有一些垂直涂抹的矩阵上。(旁注:根据 mm 参考指南,长按或按住 Control 键点击将显示可视化元素的实际数值。)
  • 请记住,由于我们处于中间层 (5),因此此注意力头的输入是中间表示,而不是原始的分词文本。因此,输入中观察到的模式本身就发人深省——特别是,强烈的垂直线程是特定的嵌入位置,其值在序列的漫长延伸中(有时几乎是整个序列)保持统一的高幅度。
  • 然而,有趣的是,输入序列中的第一个向量是独特的,它不仅打破了这些高幅度列的模式,而且几乎在每个位置都携带着非典型值(旁注:此处未可视化,但此模式在多个样本输入中重复)。

注意:关于最后两点,值得重申的是,我们正在可视化**单个样本输入**的计算。实际上,我发现每个注意力头都有一个它将在大量样本中始终(但不完全相同地)表达的特征模式(即将推出的注意力头浏览器将提供大量样本供使用),但查看任何包含激活的可视化时,请务必记住,输入的完整分布可能会以微妙的方式影响它所引发的思想和直觉。

最后,再推荐一次直接探索动画

4c 注意力头有趣的不同之处

在继续之前,再演示一下简单地探究模型以详细了解其工作原理的有用性。

这是 GPT2 的另一个注意力头。它的行为与上面的第 5 层第 4 个注意力头截然不同——正如人们所预期的那样,因为它位于模型的非常不同的部分。这个注意力头位于第一层:第 0 层,第 2 个注意力头(在mm中打开,可能需要几秒钟来加载模型权重)。

This is another attention head from GPT2

注意事项

  • 这个注意力头将注意力均匀地分散开来。这导致了对 `V`(或者说,`V` 的适当因果前缀)的相对**非加权**平均值传递到 `attn @ V` 中的每一行,如此动画所示:随着我们在注意力分数三角形中向下移动,`attn[i] @ V` 向量-矩阵乘积与 `V` 的简单缩放、逐渐显现的副本之间只有微小波动。
  • `attn @ V` 具有惊人的垂直一致性——在嵌入的大部分列区域中,相同的值模式在**整个序列**中持续存在。可以将这些视为每个标记共享的属性。
  • 旁注:一方面,考虑到非常均匀分布的注意力所产生的影响,人们可能会期望 `attn @ V` 中存在**一些**一致性。但是,每行都是仅从 `V` 的因果子序列而不是整个序列构建的——为什么这没有引起更多的变化,比如随着序列的向下移动而逐渐变形?通过目视检查,V 的长度并不均匀,所以答案一定在于其值分布的某些更微妙的特性。
  • 最后,这个注意力头的输出在输出投影权重之后,垂直方向上甚至更加均匀,这强化了这种直觉。
  • 总的来说,这个注意力头产生的极其规则、高度结构化的信息,可能可以通过一些……不那么奢华的计算方式获得,这种想法很难被抗拒。当然,这不是一个未探索的领域,但是可视化计算的特异性和信号丰富性对于产生新想法和推理现有想法很有用。

最后,再补充一点,

4d 重温主旨:免费的不变性

回过头来看,值得重申的是,我们能够可视化注意力头这种非平凡的复合操作并保持其直观性的原因是重要的代数属性——例如参数形状如何受约束,或者哪些并行化轴与哪些操作相交——**无需额外思考**:它们直接源于可视化对象的几何结构,而不是需要记住的额外规则。

例如,在这些注意力头可视化中,立即显而易见的是

  • `Q` 和 `attn @ V` 长度相同,`K` 和 `V` 长度相同,并且这些对的长度彼此独立。
  • `Q` 和 `K` 宽度相同,`V` 和 `attn @ V` 宽度相同,并且这些对的宽度彼此独立。

这些属性是**通过构造**成立的,是组成部分在复合结构中所处位置及其方向的简单结果。

这种“免费属性”的优势在探索规范结构的变体时尤其有用——一个明显的例子是自回归逐词解码中的单行高注意力矩阵(在mm中打开)。

the one-row-high attention matrix in autoregressive token-at-a-time decoding

5 并行化注意力

在上面的头 5,层 4 的动画中,我们可视化了注意力头中 6 个矩阵乘法中的 4 个

作为向量-矩阵乘积的融合链,这证实了几何直觉,即从输入到输出的整个左关联链在共享 `i` 轴上是层状的,并且可以并行化。

5a 示例:沿 `i` 轴分区

为了在实践中并行化计算,我们将沿 `i` 轴将输入分成块。我们可以在工具中可视化此分区,方法是指定给定轴被分成特定数量的块——在这些示例中,我们将使用 8,但这没什么特别的。

除其他外,此可视化清楚地表明,`wQ`(用于输入投影)、`K_t` 和 `V`(用于注意力)以及 `wO`(用于输出投影)需要完整地用于每个并行计算,因为它们沿着这些矩阵的未分区维度与分区矩阵相邻(在mm中打开

wQ (for in-projection), K_t and V (for attention) and wO (for out-projection) are needed in their entirety by each parallel computation

5b 示例:双重分区

作为沿**多个**轴进行分区的示例,我们可以可视化最近在该领域进行创新的工作(块并行 Transformer,基于 Flash Attention 及其前身等工作)。

首先,BPT 如上所述沿 `i` 轴分区——实际上,这种序列块的水平分区一直延伸到注意力层的第二(FFN)半部分。 (我们将在后面的部分中进行可视化。)

为了完全解决上下文长度问题,MHA 中又增加了一个分区——即注意力计算本身的分区(即,沿 `Q @ K_t` 的 `j` 轴进行分区)。这两个分区共同将注意力分为一个块网格(在mm中打开)。

The two partitions together divide attention into a grid of blocks

此可视化清楚地表明

  • 这种双重分区作为解决上下文长度问题的方法的有效性,因为我们现在已经明显地将注意力计算中序列长度的每次出现都进行了分区
  • 第二次分区的影响范围:从几何结构上可以清楚地看出,`K` 和 `V` 的输入投影计算可以与核心双矩阵乘法一起分区。

请注意一个微妙之处:这里的视觉含义是,我们还可以沿 `k` 轴并行化随后的矩阵乘法 `attn @ V` 并求和部分结果,采用split-k 样式,从而并行化整个双矩阵乘法。但是 `sdpa()` 中的逐行 softmax 增加了在计算 `attn @ V` 的相应行之前,每行必须将其所有段归一化的要求,这在注意力计算和最终矩阵乘法之间增加了一个额外的逐行步骤。

6 注意力层中的大小

注意力层的第一部分(MHA)因其二次复杂度而闻名于计算量大,但第二部分(FFN)由于其隐藏维度(通常是模型嵌入维度的4倍)的宽度也同样计算量大。可视化整个注意力层的“生物量”有助于直观地了解层两部分的相互比较。

6a 可视化完整层

下面是一个完整的注意力层,前半部分(MHA)在背景中,后半部分(FFN)在前景中。和往常一样,箭头指向计算方向。

注意事项

  • 此可视化不描绘单个注意力头,而是显示围绕中心双矩阵乘法的未切片 Q/K/V 权重和投影。当然,这并不是对完整 MHA 操作的忠实可视化——但这里的目标是更清楚地了解层两部分的相对矩阵**大小**,而不是每部分执行的相对计算量。(此外,此处使用随机值而不是真实权重。)
  • 这里使用的维度已经缩小,以保持浏览器(相对)愉快,但比例保持不变(来自 NanoGPT 的小配置):模型嵌入维度 = 192(原为 768),FFN 嵌入维度 = 768(原为 3072),序列长度 = 256(原为 1024),尽管序列长度对模型来说不是根本性的。(视觉上,序列长度的变化将表现为输入刀片的宽度变化,从而影响注意力枢纽的大小和下游垂直平面的高度。)

在 mm 中打开:

a full attention layer with the first half (MHA) in the background and the second (FFN) in the foreground

6b 可视化 BPT 分区层

简要回顾分块并行 Transformer,这里我们在整个注意力层(如上所述省略了单个注意力头)的上下文中可视化 BPT 的并行化方案。特别注意沿 `i` 轴(序列块)的分区如何贯穿 MHA 和 FFN 两半部分(在mm中打开

visualize BPT's parallelization scheme in the context of an entire attention layer

6c FFN 分区

该可视化建议进行额外的分区,与上述分区正交——在注意力层的 FFN 半部分中,将双重矩阵乘法 `(attn_out @ FFN_1) @ FFN_2` 进行分割,首先沿 `j` 轴对 `attn_out @ FFN_1` 进行分区,然后在与 `FFN_2` 的后续矩阵乘法中沿 `k` 轴进行分区。此分区将 `FFN` 权重的两层都进行切片,从而降低了计算中每个参与者的容量要求,但代价是最终需要对部分结果进行求和。

这是应用于未分区注意力层(在mm中打开

what this partition looks like applied to an otherwise unpartitioned attention layer

这是应用于 BPT 分区的层(在mm中打开

applied to a layer partitioned a la BPT

6d 可视化逐词解码

在自回归逐词解码过程中,查询向量由单个标记组成。在那种情况下,注意力层是什么样子,在脑海中有一个清晰的图像是很有启发性的——一个单一的嵌入行在一个巨大的平铺权重平面中运行。

除了强调权重相对于激活的巨大性之外,这种视图还让人联想到这样一个概念:`K_t` 和 `V` 的功能类似于一个 6 层 MLP 中动态生成的层,尽管 MHA 本身的复用/解复用计算(此处未显示,如上所述)使这种对应关系不精确(在mm中打开)。

the mux/demux computations of MHA itself

7 LoRA

最近的 LoRA 论文(LoRA: Low-Rank Adaptation of Large Language Models)描述了一种基于微调期间引入的权重增量是低秩的思想的有效微调技术。根据论文,这“允许我们通过优化密集层在适应过程中变化的秩分解矩阵来间接训练神经网络中的一些密集层[……],同时保持预训练权重冻结。”

7a 基本思想

简而言之,关键的步骤是训练权重矩阵的**因子**而不是矩阵本身:用一个 `I x K` 张量和一个 `K x J` 张量的矩阵乘法来替换 `I x J` 权重张量,并将 `K` 保持在一个较小的数字。

如果 `K` 足够小,则大小增益可能巨大,但权衡是降低 `K` 会降低产品所能表达的秩。作为大小节省和结果结构化效果的快速说明,这里是随机的 `128 x 4` 左和 `4 x 128` 右参数的矩阵乘法——又名 `128 x 128` 矩阵的秩 4 分解。注意 `L @ R` 中的垂直和水平模式(在mm中打开)。

a matmul of random 128 x 4 left and 4 x 128 right arguments

7b 将 LoRA 应用于注意力头

LoRA 将这种因子分解方法应用于微调过程的方式是:

  • 为每个要微调的权重张量创建一个低秩分解并训练这些因子,同时保持原始权重不变
  • 微调后,将每对低秩因子相乘,得到一个与原始预训练权重张量形状相同的矩阵,并将其添加到原始预训练权重张量中

以下可视化展示了一个注意力头,其中权重张量 wQwK_twVwO 被低秩分解 wQ_A @ wQ_B 等替换。在视觉上,因子矩阵在风车叶片的边缘显示为低矮的栅栏(在 mm 中打开 – 按空格键停止旋转)。

8 总结

8a 征求反馈

我发现这种可视化矩阵乘法表达式的方式对于建立直觉和理解不仅是矩阵乘法本身,还包括机器学习模型及其计算的许多方面(从效率到可解释性)都非常有帮助。

如果您尝试后有任何建议或意见,我非常希望听到,无论是通过这里的评论还是在仓库中

8b 下一步

  • 有一个基于该工具构建的GPT2 注意力头探索器,我目前正用它来清点和分类在该模型中发现的注意力头特征。(这是我用于查找和探索本文中注意力头的工具。)完成后,我计划发布一份带有清单的说明。
  • 正如开头所提到的,将这些可视化嵌入到 Python 笔记本中非常简单。但是会话 URL 可能会变得……难以驾驭,因此,拥有用于从配置对象构建它们的 Python 端实用程序将很有用,类似于参考指南中使用的简单 JavaScript 辅助工具。
  • 如果您有一个您认为可能会受益于此类可视化的用例,但又不清楚如何使用该工具来实现,请联系我们!我不一定要进一步扩展其核心可视化功能(适合该工作的正确工具等),但是例如,用于以编程方式驱动它的 API 非常基础,在这方面还有很多可以做的事情。