为什么变压器架构中需要残差连接?

机器算法验证 神经网络 变压器 注意力 残差网络
2022-03-16 12:00:14

残差连接的动机通常是非常深的神经网络在训练期间倾向于“忘记”其输入数据集样本的某些特征。

这个问题可以通过以下方式将输入x与典型前馈计算的结果相加来规避:

F(x)+x=[W2σ(W1x+b1)+b2]+x.

这在 [ 1 ]中示意性地表示为:

在此处输入图像描述

另一方面,众所周知,Transformer 架构有一些残差网络,如下图所示:

在此处输入图像描述

问题:残差连接在非常深的网络架构的背景下被激发,但与 [ 1 ]中表现出色的网络相比,注意力块执行的计算量非常少;那么,transformer 架构的注意力块中存在快捷连接的动机是什么?

1个回答

在 Transformer 中使用残差连接的原因更多是技术性的,而不是架构设计的动机。

残差连接主要有助于缓解梯度消失问题。在反向传播期间,信号乘以激活函数的导数。在 ReLU 的情况下,这意味着在大约一半的情况下,梯度为零。如果没有残差连接,大部分训练信号会在反向传播过程中丢失。残差连接减少了影响,因为求和相对于导数是线性的,因此每个残差块也得到一个不受梯度消失影响的信号。残差连接的求和操作在计算图中形成了一条梯度不会丢失的路径。

残差连接的另一个影响是信息在 Transformer 层堆栈中保持本地化。自注意力机制允许网络中的任意信息流,从而任意排列输入标记。然而,残差连接总是“提醒”原始状态的表示。在某种程度上,残差连接保证了输入标记的上下文表示真正代表了标记。