发布时间:2023-10-06 11:30
正向传播是指对神经网络沿着从输入层到输出层的顺序,依次计算并存储模型的中间变量(包括输出)。
假设输入是一个特征为 x ∈ R d \boldsymbol{x} \in \mathbb{R}^d x∈Rd的样本,且不考虑偏差项,那么中间变量
z = W ( 1 ) x , \boldsymbol{z} = \boldsymbol{W}^{(1)} \boldsymbol{x}, z=W(1)x,
(矩阵相乘)
其中 W ( 1 ) ∈ R h × d \boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)∈Rh×d是隐藏层的权重参数。把中间变量 z ∈ R h \boldsymbol{z} \in \mathbb{R}^h z∈Rh输入按元素运算的激活函数 ϕ \phi ϕ后,将得到向量长度为 h h h的隐藏层变量
h = ϕ ( z ) . \boldsymbol{h} = \phi (\boldsymbol{z}). h=ϕ(z).
隐藏层变量 h \boldsymbol{h} h也是一个中间变量。假设输出层参数只有权重 W ( 2 ) ∈ R q × h \boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)∈Rq×h,可以得到向量长度为 q q q的输出层变量
o = W ( 2 ) h . \boldsymbol{o} = \boldsymbol{W}^{(2)} \boldsymbol{h}. o=W(2)h.
假设损失函数为 ℓ \ell ℓ,且样本标签为 y y y,可以计算出单个数据样本的损失项
L = ℓ ( o , y ) . L = \ell(\boldsymbol{o}, y). L=ℓ(o,y).
根据 L 2 L_2 L2范数正则化的定义,给定超参数 λ \lambda λ,正则化项即(超参数 λ \lambda λ即表示惩罚的力度)
s = λ 2 ( ∣ W ( 1 ) ∣ F 2 + ∣ W ( 2 ) ∣ F 2 ) , s = \frac{\lambda}{2} \left(|\boldsymbol{W}^{(1)}|_F^2 + |\boldsymbol{W}^{(2)}|_F^2\right), s=2λ(∣W(1)∣F2+∣W(2)∣F2),
其中矩阵的Frobenius范数等价于将矩阵变平为向量后计算 L 2 L_2 L2范数。最终,模型在给定的数据样本上带正则化的损失为
J = L + s . J = L + s. J=L+s.
我们将 J J J称为有关给定数据样本的目标函数。
反向传播用于计算神经网络中的参数梯度。反向传播利用微积分中的链式法则,沿着从输出层到输入层的顺序进行依次计算目标函数有关神经网络各层的中间变量以及参数的梯度。
依据链式法则,我们可以知道:
∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o . \frac{\partial J}{\partial \boldsymbol{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{o}}\right) = \frac{\partial L}{\partial \boldsymbol{o}}. ∂o∂J=prod(∂L∂J,∂o∂L)=∂o∂L.
( ∂ J ∂ L = 1 , ∂ J ∂ s = 1 ) \left( \frac{\partial J}{\partial L} = 1, \quad \frac{\partial J}{\partial s} = 1\right) (∂L∂J=1,∂s∂J=1)
其中 prod \text{prod} prod运算符将根据两个输入的形状,在必要的操作(如转置和互换输入位置)后对两个输入做乘法。
∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) \frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} ∂W(2)∂J=prod(∂o∂J,∂W(2)∂o)+prod(∂s∂J,∂W(2)∂s)=∂o∂Jh⊤+λW(2)
其中:
( ∂ s ∂ W ( 1 ) = λ W ( 1 ) , ∂ s ∂ W ( 2 ) = λ W ( 2 ) ) \left(\frac{\partial s}{\partial \boldsymbol{W}^{(1)}} = \lambda \boldsymbol{W}^{(1)},\quad\frac{\partial s}{\partial \boldsymbol{W}^{(2)}} = \lambda \boldsymbol{W}^{(2)}\right) (∂W(1)∂s=λW(1),∂W(2)∂s=λW(2))
还有:
∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) \frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} ∂W(2)∂J=prod(∂o∂J,∂W(2)∂o)+prod(∂s∂J,∂W(2)∂s)=∂o∂Jh⊤+λW(2)
∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o \frac{\partial J}{\partial \boldsymbol{h}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{h}}\right) = {\boldsymbol{W}^{(2)}}^\top \frac{\partial J}{\partial \boldsymbol{o}} ∂h∂J=prod(∂o∂J,∂h∂o)=W(2)⊤∂o∂J
∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) \frac{\partial J}{\partial \boldsymbol{z}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{h}}, \frac{\partial \boldsymbol{h}}{\partial \boldsymbol{z}}\right) = \frac{\partial J}{\partial \boldsymbol{h}} \odot \phi'\left(\boldsymbol{z}\right) ∂z∂J=prod(∂h∂J,∂z∂h)=∂h∂J⊙ϕ′(z)
所以,可以得到:
∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) \frac{\partial J}{\partial \boldsymbol{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{z}}, \frac{\partial \boldsymbol{z}}{\partial \boldsymbol{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(1)}}\right) = \frac{\partial J}{\partial \boldsymbol{z}} \boldsymbol{x}^\top + \lambda \boldsymbol{W}^{(1)} ∂W(1)∂J=prod(∂z∂J,∂W(1)∂z)+prod(∂s∂J,∂W(1)∂s)=∂z∂Jx⊤+λW(1)