在我们了解xLSTM和Mamba之前,我们有必要先了解一下之前的前置时序模型(循环神经网络)。
LSTM的出现是由于RNN模型出现的梯度消失和梯度爆炸的问题。
1 RNN Model
RNN的模型架构如下:
RNN的计算公式如下:
2 RNN的梯度消失和梯度爆炸
说起梯度消失和梯度爆炸,我们先来看看梯度怎么算,具体来说是反向传播过程中链式法则如何计算梯度。
梯度实际上就是求loss针对需要更新的参数的偏导,故哪些参数需要优化和更新,哪些参数就需要求倒数。
我们开始计算RNN的梯度,RNN需要更新的参数为Wxh, Whh, Whq三个权重参数:
在Whh和Wxh的求导中,如果时序过长,中间会出现大量的ht对ht-1的偏导,这其中还夹杂激活函数的导数,如果激活函数或者ht对ht-1的偏导小于1或者都大于1,则很容易导致梯度消失或者梯度爆炸的问题。
3 LSTM Model
LSTM的架构如下:
LSTM的公式原理如下:
4 LSTM如何解决梯度消失和梯度爆炸
在计算loss对W的梯度时,绕不过Ct对Ct-1的梯度,故我们先计算Ct对Ct-1的梯度:
当ft接近1时,这个梯度整体是有可能大于1的。
- 1 cell state传播函数中的“加法”结构确实起了一定作用,它使得导数有可能大于1
- 2 LSTM中逻辑门的参数可以一定程度控制不同时间步梯度消失的程度。
5 为什么说细胞状态是全局信息,隐状态是局部信息
我们从公式来说,Ct来源于Ct-1和~Ct,而且分别由遗忘门和输入门控制,这导致Ct无法做大幅度的更新,即既无法完全使用来自输入门的信息,也无法做到对过去信息的完全遗忘。
故细胞状态Ct为全局信息的解释:
- 1 Ct存在过去Ct-1的信息;
- 2 Ct的信息更新幅度较小,比较稳定,同时存在一部分来自输入的信息~Ct,也存在一定的过去的信息Ct-1。
隐状态ht为局部信息的解释:
- 1 虽然ht的来源是Ct,但是由于其只受到一个输出门的控制,该门对于信息的过滤有一票否决权,同时该门是受到输入X的控制,故认为相比Ct更偏局部信息。