在我们了解xLSTM和Mamba之前,我们有必要先了解一下之前的前置时序模型(循环神经网络)。

LSTM的出现是由于RNN模型出现的梯度消失和梯度爆炸的问题。

1 RNN Model

RNN的模型架构如下:

RNN的计算公式如下:

2 RNN的梯度消失和梯度爆炸

说起梯度消失和梯度爆炸,我们先来看看梯度怎么算,具体来说是反向传播过程中链式法则如何计算梯度。

梯度实际上就是求loss针对需要更新的参数的偏导,故哪些参数需要优化和更新,哪些参数就需要求倒数。

我们开始计算RNN的梯度,RNN需要更新的参数为Wxh, Whh, Whq三个权重参数:

在Whh和Wxh的求导中,如果时序过长,中间会出现大量的ht对ht-1的偏导,这其中还夹杂激活函数的导数,如果激活函数或者ht对ht-1的偏导小于1或者都大于1,则很容易导致梯度消失或者梯度爆炸的问题。

3 LSTM Model

LSTM的架构如下:

LSTM的公式原理如下:

4 LSTM如何解决梯度消失和梯度爆炸

在计算loss对W的梯度时,绕不过Ct对Ct-1的梯度,故我们先计算Ct对Ct-1的梯度:

当ft接近1时,这个梯度整体是有可能大于1的。

  • 1 cell state传播函数中的“加法”结构确实起了一定作用,它使得导数有可能大于1
  • 2 LSTM中逻辑门的参数可以一定程度控制不同时间步梯度消失的程度。

5 为什么说细胞状态是全局信息,隐状态是局部信息

我们从公式来说,Ct来源于Ct-1和~Ct,而且分别由遗忘门和输入门控制,这导致Ct无法做大幅度的更新,即既无法完全使用来自输入门的信息,也无法做到对过去信息的完全遗忘。

故细胞状态Ct为全局信息的解释:

  • 1 Ct存在过去Ct-1的信息;
  • 2 Ct的信息更新幅度较小,比较稳定,同时存在一部分来自输入的信息~Ct,也存在一定的过去的信息Ct-1

隐状态ht为局部信息的解释:

  • 1 虽然ht的来源是Ct,但是由于其只受到一个输出门的控制,该门对于信息的过滤有一票否决权,同时该门是受到输入X的控制,故认为相比Ct更偏局部信息

6 参考文献