博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
8.8LSTM作为元学习器学习梯度下降
阅读量:3951 次
发布时间:2019-05-24

本文共 3152 字,大约阅读时间需要 10 分钟。

文章目录

本节讨论的是“元学习”,我们将讨论用 LSTM 做优化器的方法。我们所指的训练,是指训练优化器 LSTM 的参数。

1、介绍

我们可以把元学习的过程,当然 RNN 来看。首先复习 RNN 。多数时候,我们说 RNN ,其实就是在用 LSTM 。

比较 LSTM 中的运算与 meta-learning 中的更新式,会发现有一些相似之处。

在实际操作中,我们做了非常大的简化。

此外,我们很久以前介绍过一些优化方法,如 RMSProp 、Momentum ,是需要过去的梯度来确定现在的梯度的。因此,可以根据这个为 LSTM 做些改进。

具体LSTM:

2、引入 RNN

image-20210307145859598

上次说到meta learning是在训练一个learning algorithm。使用的方法是梯度下降,这次我们研究如何将这个learning algorithm看作是一个lstm,我们通过训练这个lstm网络来实现我们的meta learning。

我们观察这个网络,发现很像是一个RNN,我们的training data就像我们rnn中的输入 x x x,之后参数 ϕ \phi ϕ就像是之前RNN中的 h i h^i hi,不断的更新。

image-20210307150141800

我们先来简短的回顾一下什么是RNN。输入 x x x是输入的各个序列的字符(可以理解为是吃进去的字),其二是一个 h h h h h h可以是参数,也可以是人为设定的。输出是y和h。我们要注意的是输入的h和输出的h一定要具有相同的格式的。

LSTM

image-20210307150316563

我们再来回顾一下LSTM和普通的RNN之间有什么不同的地方。RNN是两个输入,两个输出,而LSTM是具有三个输入,三个输出。其中 c c c是变化很小的,而h是具有很大变化的。所以很多人都会说 c c c是可以存储很久远记忆的, c c c也是可以遗忘信息的。

image-20210307150543514

我们再来回顾一下LSTM的计算流程,就是我们将 h h h x x x接在一起,之后和一个参数矩阵w相称乘,之后再加上一个激活函数得到 z z z。之后再用x和h接在一起,之后乘上一个参数矩阵 w i w^i wi,之后加一个激活函数得到 z i z^i zi,同理得到 z f z^f zf z o z^o zo z i z^i zi为输入门, z f z^f zf为遗忘门, z o z^o zo为输出门。

image-20210307160137188

我们得到 z z z z i z^i zi z f z^f zf z o z^o zo以后,我们将zf和ct-1做点乘,再加上 z i z^i zi z z z做点乘得到新的 c t c^t ct。之后再将 c t c^t ct经过一个激活函数tanh,之后再和 z o z^o zo做点乘,得到 h t h^t ht。之后再用 h t h^t ht和w做乘法,再加一个激活函数,得到 y t y^t yt。以上就是LSTM中一个计算单元的全部计算过程。

我们刚才得到了新的参数 h t h^t ht c t c^t ct,之后又有新的输入 x t + 1 x_{t+1} xt+1。之后我们就不断的进行循环计算,得到最后的值。

3、LSTM变换成梯度下降

image-20210307160635803

我们将LSTM的式子和我们最初的梯度下降的式子都罗列出来看看,上面是梯度下降的式子,下面是LSTM的式子。

我们发现梯度下降左边是 θ t \theta^t θt,右边是 θ t − 1 \theta^{t-1} θt1,之后LSTM的第一个式子左边是 c t c^{t} ct,右边是 c t − 1 c^{t-1} ct1。于是我们就想,可不可以把 c c c当成 θ \theta θ来看呢!其实就是将c当成是神经网络的参数来看待。

我们将 c c c换为 θ \theta θ以后,就变成上图那样,之后我们发现将 z f z^f zf都换为1,之后再将 z i z^i zi换为一个常量矩阵。

为了使两个式子更相像,我们把 LSTM 的 input 从 h t + 1 , x t h^{t+1}, x^t ht+1,xt 变成 − ∇ θ l -\nabla_\theta l θl

上图,将 z z z换为梯度的负数 − ∇ θ l -\nabla_\theta l θl ,整个LSTM就换成是一个梯度下降的式子。可以理解为梯度下降就是LSTM的特殊形式。

我们通过以上的分析计算,发现其实 z f z^f zf z i z^i zi等等都是直接给出的,但是我们现在想能否将其动态的学习出来呢!这将是我们所想要解决的问题。我们计算zf时候,就像是在做一个回归的运算,而在计算 z i z^i zi的时候,更像是在动态的决定一个learning rate过程。

我们的输入现在是负数的梯度,我们想能否换成是别的信息呢!

4、LSTM学习梯度下降

image-20210307161703376

我们来看一下梯度下降版的LSTM是什么情况的。计算公式如上图所示。我们的输入是一个 θ 0 \theta^0 θ0。同时我们提取出一个batch的数据,之后我们将其通过参数 θ \theta θ,得到梯度的负数,我们将这两个作为“LSTM”的输入,之后进行训练,假设我们循环三次LSTM,之后得到模型参数 θ 3 \theta^3 θ3,然后在Testing Data上计算loss,调整LSTM,使得loss最小,通过损失函数的梯度下降,不断的更新参数。

image-20210307161805035

如上,我们使用“简化版LSTM”,更新参数。我们的目标 loss 即为最终算出的 l ( θ i ) l(\theta^i) l(θi)

有几个要注意的地方:

  1. 在一般的 LSTM 里面,每次的 c c c x x x 是无关的;

  2. 而在我们这里,现在的 θ \theta θ 会影响到 − ∇ θ l -\nabla_\theta l θl

很多人都说,我们神经网络有成千上万的参数,难道我们需要成千上万个LSTM吗!我们训练一千个LSTM就很费劲了,那么我们要怎么做呢!

我们其实就训练一个LSTM的神经元,之后所有的参数,都使用这样的LSTM来更新。那么又会有人问,使用相同的LSTM的话,难道不会使所有的参数训练以后都相同吗!其实是不会的,因为我们参数的初始值是不相同的,所以我们训练的梯度也是不相同的,所以我们得到的参数也不相同。

我们这样做其实是有自身的道理的:

  1. 我们不可能训练成千上万个LSTM的神经元,因为过于巨大
  2. 我们以往的更新策略其实也是对于不同的参数采用相同的更新策略
  3. 我们以往使用MAML的时候我么训练的模型和测试的模型必须是相同的,但是现在我们没有这方面的顾虑了,但是现在不一样了,因为我们使用相同的LSTM更新参数,无论参数个数多少都无所谓了。

实验结果

image-20210307162657517

如上,我们发现 Forget gate 一般保持在 1 左右;并且,其学习率是动态变化的。

image-20210307162926232

通过我们之前的学习,我们知道很多时候参数的更新不仅仅是依靠当前的梯度,很多时候是依靠一直以来的梯度加在一起所得到的结果的。但是我们刚才并没有思考这些,所以我们现在思考能不能将过去的梯度和现在的梯度做一个整合处理。

5、梯度下降LSTM +动量

image-20210307163056409

如上,我们再加一层绿色的 LSTM ,希望 m 存储过去的梯度。在架构中加入“过去的梯度”,让过去的梯度也参与现在梯度的决定,类似优化器 Momentum

实验结果

image-20210307163314365

如上,LSTM 作为优化器,得到了很惊人的效果。而这个 LSTM 中的参数在 1 * 20 的小神经网络中训练好了;拿到测试任务中,也训练得起来。

但是如最后一张图,训练时与测试时使用不同的激活函数,会坏掉。

参考资料

李宏毅2020人类语言处理

转载地址:http://bdyzi.baihongyu.com/

你可能感兴趣的文章
技术人攻略访谈二十九:平行世界守护者
查看>>
制作initramfs/initrd镜像
查看>>
浅析busybox查找命令和调用相应命令函数的实现流程框架
查看>>
利用linux dd和tr命令生成特定的数据
查看>>
Fundamentals of battery fuel-gauging
查看>>
armlinux内核启动--内存初始化管理
查看>>
rk3188--4.android用initrd文件系统启动流程
查看>>
rk3188--3.initramfs_data.cpio的生成及使用
查看>>
小议基于Android平台的流媒体播放器的设计 转载
查看>>
linux 2.6 输入子系统 键盘驱动的实现
查看>>
Linux Input Device
查看>>
学习ARM+Linux的很好的资料
查看>>
linux spi子系统 驱动分析续
查看>>
linux设备模型深探
查看>>
SPI设备的驱动
查看>>
Linux 2.6下SPI设备模型--------基于AT91RM9200分析
查看>>
struct device 结构
查看>>
S3C2440上触摸屏驱动实例开发讲解
查看>>
一个基于linux2.6内核下S3C2410触摸屏驱动
查看>>
Linux 内核/sys 文件系统介绍
查看>>