Preface

  In this post, we construct a unified framework for SGD with momentum and give its PyTorch implementation.

HB and NAG

In the previous post, we introduced the Heavy-Ball method (HB) and discussed the convergence rates of HB+GD and HB+SGD in several settings. HB+SGD can be written in the following form:

$$ \boldsymbol{\theta}^{(\tau+1)}=\begin{cases} \boldsymbol{\theta}^{(\tau)}-\eta^{(\tau)} \cdot \boldsymbol{g}_{\tau},&\tau =0\\ \boldsymbol{\theta}^{(\tau)}-\eta^{(\tau)} \cdot \boldsymbol{g}_{\tau}+\beta\cdot (\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)}),&\tau \geq 1 \end{cases} \tag{1} $$

We also showed that HB can be written in the following multi-stage form:

$$ \begin{cases} \boldsymbol{m}^{(\tau)}=\beta\boldsymbol{m}^{(\tau-1)}-\eta^{(\tau)}\boldsymbol{g}_{\tau} \\ \boldsymbol{\theta}^{(\tau+1)} = \boldsymbol{\theta}^{(\tau)} + \boldsymbol{m}^{(\tau)} \end{cases}\tag{2} $$

where $\tau\geq 1$ and $\boldsymbol{m}^{(0)}=-\eta^{(0)}\boldsymbol{g}_0$.
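
The equivalence of (1) and (2) is easy to check numerically. Below is a minimal sketch on a toy quadratic loss $\mathcal L(\theta)=\frac{1}{2}\theta^2$ (so the full gradient is simply $\boldsymbol{g}_\tau=\theta^{(\tau)}$); the step size, momentum coefficient, and starting point are arbitrary choices for illustration:

eta, beta, steps = 0.1, 0.9, 20
grad = lambda theta: theta                         # gradient of L(theta) = 0.5 * theta ** 2

# form (1)
theta_prev = 2.0                                   # theta^{(0)}
theta = theta_prev - eta * grad(theta_prev)        # theta^{(1)}, the tau = 0 branch
for _ in range(steps):
    theta, theta_prev = theta - eta * grad(theta) + beta * (theta - theta_prev), theta

# form (2)
x = 2.0                                            # theta^{(0)}
m = -eta * grad(x)                                 # m^{(0)} = -eta * g_0
x = x + m                                          # theta^{(1)} = theta^{(0)} + m^{(0)}
for _ in range(steps):
    m = beta * m - eta * grad(x)                   # m^{(tau)}       = beta * m^{(tau-1)} - eta * g_tau
    x = x + m                                      # theta^{(tau+1)} = theta^{(tau)} + m^{(tau)}

print(abs(theta - x) < 1e-9)                       # expected: True
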
  In addition, Nesterov, an outstanding Soviet researcher in optimization, proposed another momentum-based algorithm (Reference 1):

$$ \boldsymbol{\theta}^{(\tau+1)}=\begin{cases} \boldsymbol{\theta}^{(\tau)}-\eta^{(\tau)} \cdot \boldsymbol{g}_{\tau}-\beta\eta^{(\tau)}\boldsymbol{g}_{\tau} ,&\tau =0\\ \boldsymbol{\theta}^{(\tau)}-\eta^{(\tau)}\nabla\mathcal L(\boldsymbol{\theta}^{(\tau)}+\beta(\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)}))+\beta(\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)}),&\tau\geq 1 \end{cases} \tag{3} $$

This is usually called Nesterov accelerated gradient (NAG). It can likewise be written in a multi-stage form:

$$ \begin{cases} \boldsymbol{m}^{(\tau)}=\beta \boldsymbol{m}^{(\tau-1)}-\eta^{(\tau)}\nabla\mathcal L(\boldsymbol{\theta}^{(\tau)}+\beta(\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)})) \\ \boldsymbol{\theta}^{(\tau+1)} = \boldsymbol{\theta}^{(\tau)} + \boldsymbol{m}^{(\tau)} \end{cases}\tag{4} $$

where $\tau\geq 1$ and $\boldsymbol{m}^{(0)}=-\eta^{(0)}\boldsymbol{g}_0-\beta\eta^{(0)}\boldsymbol{g}_0$.
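
Analogously, here is a minimal sketch checking (3) against (4) on the same toy quadratic loss $\mathcal L(\theta)=\frac{1}{2}\theta^2$ (again with arbitrary illustrative hyperparameters); note how the gradient is evaluated at the look-ahead point:

eta, beta, steps = 0.1, 0.9, 20
grad = lambda theta: theta                                  # gradient of L(theta) = 0.5 * theta ** 2

# form (3)
theta_prev = 2.0                                            # theta^{(0)}
theta = theta_prev - (1 + beta) * eta * grad(theta_prev)    # theta^{(1)}, the tau = 0 branch
for _ in range(steps):
    g_ahead = grad(theta + beta * (theta - theta_prev))     # gradient at the look-ahead point
    theta, theta_prev = theta - eta * g_ahead + beta * (theta - theta_prev), theta

# form (4)
x = 2.0                                                     # theta^{(0)}
m = -(1 + beta) * eta * grad(x)                             # m^{(0)} = -eta*g_0 - beta*eta*g_0
x_prev, x = x, x + m                                        # theta^{(1)} = theta^{(0)} + m^{(0)}
for _ in range(steps):
    m = beta * m - eta * grad(x + beta * (x - x_prev))      # m^{(tau)}
    x_prev, x = x, x + m                                    # theta^{(tau+1)} = theta^{(tau)} + m^{(tau)}

print(abs(theta - x) < 1e-9)                                # expected: True
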
  Intuitively, NAG's improvement is that it does not evaluate the gradient at $\boldsymbol{\theta}^{(\tau)}$ itself; instead it looks ahead by one momentum step and evaluates the gradient at $\boldsymbol{\theta}^{(\tau)}+\beta(\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)})$.
  The figure below illustrates how HB and NAG update the parameters:

Figure: NAG accelerating the update

As shown, before the optimum is reached, the gradient at $\boldsymbol{\theta}^{(\tau)}+\beta(\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)})$ makes an acute angle with the momentum direction; this can be read as taking a tentative step forward and finding that the gradient has not reversed, so NAG accelerates the update. When the gradient at $\boldsymbol{\theta}^{(\tau)}+\beta(\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)})$ instead makes an obtuse angle with the momentum direction, the NAG update looks as in the figure below, and NAG effectively damps the oscillation.

Figure: NAG damping the oscillation

A Unified Framework for SGD with Momentum

  Combining equations (1)-(4), we now give a unified framework for SGD with momentum:

$$ \begin{cases} \boldsymbol{m}^{(\tau)}= \boldsymbol{\theta}^{(\tau-1)}-s\eta^{(\tau-1)}\boldsymbol{g}_{\tau-1} \\ \boldsymbol{\theta}^{(\tau)} =\boldsymbol{\theta}^{(\tau-1)}-\eta^{(\tau-1)}\boldsymbol{g}_{\tau-1}+\beta( \boldsymbol{m}^{(\tau)}-\boldsymbol{m}^{(\tau-1)}) \end{cases}\tag{5} $$

where $s\in \{0,1\}$, $\tau\geq 1$, and $\boldsymbol{m}^{(0)}=\boldsymbol{\theta}^{(0)}$. When $s=0$ the scheme reduces to HB, and the equivalence of (5) and (1) is immediate. When $s=1$, (5) and (4) produce the same $\boldsymbol{\theta}^{(1)}$: indeed, $\boldsymbol{m}^{(1)}=\boldsymbol{\theta}^{(0)}-\eta^{(0)}\boldsymbol{g}_0$ gives $\boldsymbol{\theta}^{(1)}=\boldsymbol{\theta}^{(0)}-\eta^{(0)}\boldsymbol{g}_0+\beta(\boldsymbol{m}^{(1)}-\boldsymbol{m}^{(0)})=\boldsymbol{\theta}^{(0)}-\eta^{(0)}\boldsymbol{g}_0-\beta\eta^{(0)}\boldsymbol{g}_0$, which matches the $\tau=0$ branch of (3). For $\tau\geq 1$, define for (4):

$$ \boldsymbol{v}^{(\tau)}=\boldsymbol{\theta}^{(\tau)}+\beta(\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)})=\boldsymbol{\theta}^{(\tau)}+\beta\boldsymbol{m}^{(\tau-1)}\tag{6} $$

so the first equation of (4) becomes:

$$ \boldsymbol{m}^{(\tau)}=\beta\boldsymbol{m}^{(\tau-1)}-\eta^{(\tau)}\nabla \mathcal L(\boldsymbol{v}^{(\tau)})\tag{7} $$

Moreover, combining the update $\boldsymbol{\theta}^{(\tau+1)}=\boldsymbol{\theta}^{(\tau)}+\boldsymbol{m}^{(\tau)}$ with (6) and then (7):

$$ \begin{align*} \boldsymbol{\theta}^{(\tau+1)}&=\boldsymbol{\theta}^{(\tau)}+\boldsymbol{m}^{(\tau)}\\ \Rightarrow\boldsymbol{v}^{(\tau+1)}-\beta\boldsymbol{m}^{(\tau)}&=\boldsymbol{v}^{(\tau)}-\beta\boldsymbol{m}^{(\tau-1)}+\boldsymbol{m}^{(\tau)}\\ \Rightarrow \boldsymbol{v}^{(\tau+1)}&=\boldsymbol{v}^{(\tau)}+\beta\boldsymbol{m}^{(\tau)}-\eta^{(\tau)}\nabla \mathcal L(\boldsymbol{v}^{(\tau)})\tag{8} \end{align*} $$

Combining (7) and (8), and renaming $\boldsymbol{v}$ back to $\boldsymbol{\theta}$ (so that $\boldsymbol{g}_\tau$ now denotes the gradient at the current iterate), we obtain an equivalent form of NAG:

$$ \begin{cases} \boldsymbol{m}^{(\tau)}=\beta \boldsymbol{m}^{(\tau-1)}-\eta^{(\tau)}\boldsymbol{g}_{\tau} \\ \boldsymbol{\theta}^{(\tau)} = \boldsymbol{\theta}^{(\tau-1)} + \beta\boldsymbol{m}^{(\tau-1)}-\eta^{(\tau-1)}\boldsymbol{g}_{\tau-1} \end{cases}\tag{9} $$

where $\boldsymbol{m}^{(0)}=-\eta^{(0)}\boldsymbol{g}_0$ and $\tau\geq 1$. In fact, (9) is exactly how TensorFlow implements NAG (see Reference 4). Proving that (9) and (5) are equivalent is now straightforward and is left as an exercise for the reader; a numerical sanity check is sketched below. References 2 and 3 give an expression similar to (5) and call it Stochastic Unified Momentum (SUM). Note that the NAG branch of (5) avoids evaluating $\nabla\mathcal L(\boldsymbol{\theta}^{(\tau)}+\beta(\boldsymbol{\theta}^{(\tau)}-\boldsymbol{\theta}^{(\tau-1)}))$, which makes it easier to implement.
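
As a numerical sanity check of the exercise (again on the toy quadratic loss $\mathcal L(\theta)=\frac{1}{2}\theta^2$ with arbitrary illustrative hyperparameters), the sketch below iterates (5) with $s=1$ and (9) with $\boldsymbol{m}^{(0)}=-\eta^{(0)}\boldsymbol{g}_0$ and compares the final iterates:

eta, beta, steps = 0.1, 0.9, 20
grad = lambda theta: theta        # gradient of L(theta) = 0.5 * theta ** 2

# form (5) with s = 1
theta, m = 2.0, 2.0               # theta^{(0)} and m^{(0)} = theta^{(0)}
for _ in range(steps):
    g = grad(theta)
    m_new = theta - eta * g                       # m^{(tau)}     = theta^{(tau-1)} - eta * g_{tau-1}
    theta = theta - eta * g + beta * (m_new - m)  # theta^{(tau)}
    m = m_new

# form (9)
x = 2.0                           # theta^{(0)}
v = -eta * grad(x)                # m^{(0)} = -eta * g_0
for _ in range(steps):
    g = grad(x)
    x = x + beta * v - eta * g    # theta^{(tau)} = theta^{(tau-1)} + beta*m^{(tau-1)} - eta*g_{tau-1}
    v = beta * v - eta * grad(x)  # m^{(tau)}     = beta*m^{(tau-1)} - eta*g_tau

print(abs(theta - x) < 1e-9)      # expected: True
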
  Is the unified treatment of momentum SGD in PyTorch of a similar form? According to Reference 5:

$$ \begin{cases} \boldsymbol{b}^{(\tau)}=\beta\boldsymbol{b}^{(\tau-1)}+(1-k)\boldsymbol{g}_{\tau}\\ \boldsymbol{m}^{(\tau)}=s\boldsymbol{g}_\tau+(s\beta-s+1)\boldsymbol{b}^{(\tau)}\\ \boldsymbol{\theta}^{(\tau)}=\boldsymbol{\theta}^{(\tau-1)}-\eta^{(\tau-1)}\boldsymbol{m}^{(\tau-1)} \end{cases} \tag{10} $$

where $s\in \{0,1\}$, $k\in[0,1]$, $\tau\geq 1$, $\boldsymbol{b}^{(0)}=\boldsymbol{g}_0$, and $\boldsymbol{m}^{(0)}=(s\beta+1)\boldsymbol{g}_0$. One can show that with a constant step size $\eta^{(\tau)}\equiv \eta$ and $k=0$, (10) is equivalent to (5); a numerical check is given after the implementation below.

A PyTorch Implementation of SUM

  Next, we implement the SUM scheme of (5) in PyTorch.

import torch
from torch.optim.optimizer import Optimizer, required

class SUM(Optimizer):
    def __init__(self, params, lr=required, beta=0.0, s=False):
        if lr is not required and lr <= 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if beta < 0.0:
            raise ValueError("Invalid momentum value: {}".format(beta))
        if s and beta <= 0:
            raise ValueError("Nesterov momentum (s=True) requires a positive beta.")
        
        defaults = dict(lr=lr, beta=beta, s=s)
        super(SUM, self).__init__(params, defaults)
    
    def __setstate__(self, state):
        super(SUM, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('s', False)
    
    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            beta = group['beta']
            s = group['s']
            lr = group['lr']
            
            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    d_p_list.append(p.grad)
                    
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        momentum_buffer_list.append(None)
                    else:
                        momentum_buffer_list.append(state['momentum_buffer'])
            
            for i, param in enumerate(params_with_grad):
                d_p = d_p_list[i]
                if beta != 0:
                    buf = momentum_buffer_list[i]  # m^{(\tau-1)}
                    if buf is None:
                        # first step, with m^{(0)} = \theta^{(0)}
                        if s:
                            # m^{(1)} = \theta^{(0)} - \eta g_0
                            momentum_buffer_list[i] = torch.clone(param).detach() - lr * d_p
                            # \theta^{(1)} = \theta^{(0)} - \eta g_0 + \beta(m^{(1)} - m^{(0)})
                            #             = \theta^{(0)} - (1 + \beta)\eta g_0
                            param.add_(d_p, alpha=-lr * (1 + beta))
                        else:
                            # m^{(1)} = \theta^{(0)}, so the momentum term vanishes
                            momentum_buffer_list[i] = torch.clone(param).detach()
                            param.add_(d_p, alpha=-lr)
                        continue
                    if s:
                        # m^{(\tau)} = \theta^{(\tau-1)} - \eta g_{\tau-1}
                        m_tau = param - lr * d_p
                    else:
                        # m^{(\tau)} = \theta^{(\tau-1)}; clone so the buffer does not alias
                        # the parameter tensor that is updated in place below
                        m_tau = torch.clone(param).detach()
                    # \theta^{(\tau)} = \theta^{(\tau-1)} - \eta g_{\tau-1} + \beta(m^{(\tau)} - m^{(\tau-1)})
                    param.add_(-lr * d_p + beta * (m_tau - buf))
                    momentum_buffer_list[i] = m_tau
                else:
                    # beta == 0: plain SGD
                    param.add_(d_p, alpha=-lr)
            
            # write the updated momentum buffers back into the optimizer state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss
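
  As a quick sanity check of the claim following (10): with a constant learning rate and dampening $k=0$, SUM with s=False should behave exactly like torch.optim.SGD with plain momentum (HB), and SUM with s=True like torch.optim.SGD with nesterov=True (NAG). The minimal sketch below compares the two on an arbitrary toy quadratic objective; the hyperparameters are only for illustration.

torch.manual_seed(0)

def run(opt_cls, w, **kwargs):
    # take 10 optimization steps on the toy objective sum((w - 1)^2)
    opt = opt_cls([w], **kwargs)
    for _ in range(10):
        opt.zero_grad()
        loss = ((w - 1.0) ** 2).sum()
        loss.backward()
        opt.step()
    return w

for s in (False, True):
    w0 = torch.randn(5)
    w_sum = run(SUM, w0.clone().requires_grad_(True), lr=0.1, beta=0.9, s=s)
    w_ref = run(torch.optim.SGD, w0.clone().requires_grad_(True),
                lr=0.1, momentum=0.9, dampening=0.0, nesterov=s)
    print(s, torch.allclose(w_sum, w_ref, atol=1e-6))  # expected: True for both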

References

  1. Yurii Nesterov. A method for solving the convex programming problem with convergence rate $O(1/k^2)$. Dokl. Akad. Nauk SSSR, 1983.
  2. Tianbao Yang et al. Unified convergence analysis of stochastic momentum methods for convex and non-convex optimization. arXiv preprint arXiv:1604.03257, 2016.
  3. Yan Yan et al. A unified analysis of stochastic momentum methods for deep learning. arXiv preprint arXiv:1808.10396, 2018.
  4. https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/experimental/SGD
  5. https://pytorch.org/docs/stable/generated/torch.optim.SGD.html?highlight=sgd