MindSpore Yidiantong · In-Depth Series: The LSTM Operator for Network Construction

Published: 2023-04-02 15:00

Dive Into MindSpore–LSTM Operator For Network Construction

Development environment

  • MindSpore 1.7.0

Contents

  • Principles
  • Documentation
  • Worked examples
  • Summary
  • References

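1. Principles

LSTM stands for Long Short-Term Memory. The vanilla RNN has a serious flaw: gradients often explode or vanish during training, so a vanilla RNN struggles to handle long-range dependencies. LSTM was proposed to solve (or at least mitigate) this problem.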

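1.1 LSTM formulas

The LSTM formulas are as follows:
[Figure 1: the LSTM formulas; each one is written out gate by gate in Section 1.3]
Here σ is the sigmoid activation function and * is the element-wise (Hadamard) product. W and b are the learnable weights and biases between the output and the input in the formulas.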

1.2 LSTM structure

To make the formulas in 1.1 easier to follow, their structure is shown in the following diagram:
[Figure 2: LSTM cell structure diagram]

1.3 LSTM gates

1.3.1 Forget gate

The forget gate formula is:

$$f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh})$$

Interpretation:

The forget gate decides how much of the information in the previous state should be discarded. It reads $h_{(t-1)}$ and $x_t$ and applies the sigmoid function $\sigma$, which outputs a value between 0 and 1. A value of 0 means the contents of the previous cell state $c_{(t-1)}$ are discarded, 1 means they are kept in full, and a value between 0 and 1 means they are partially kept.

1.3.2 Input gate

The input gate formulas are:

$$i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih})$$

$$\tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch})$$

Interpretation:

The input gate decides what information is stored in the cell state $c_t$. It reads $h_{(t-1)}$ and $x_t$ and applies the sigmoid function $\sigma$, which outputs a value between 0 and 1.

A second part works together with the input gate. Its inputs are also $h_{(t-1)}$ and $x_t$, but it uses the tanh activation function. This part is denoted $\tilde{c}_t$ and is called the candidate state.

1.3.3 Cell state

The cell state formula is:

$$c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t$$

Interpretation:

$c_t$ is computed from $c_{(t-1)}$.

The old cell state $c_{(t-1)}$ is multiplied by the forget gate result, which decides how much of the old cell state is kept and how much is forgotten. The input gate $i_t$ is then multiplied by the candidate state $\tilde{c}_t$, and the result is added to the cell state. This determines how much of the new input information enters the cell state.
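
As a small worked instance of the cell state formula, the sketch below uses NumPy with made-up gate values. The numbers are illustrative assumptions and are not taken from the example outputs later in this article.

```python
import numpy as np

# Illustrative values only; in a real LSTM these come from the gate formulas.
f_t = np.array([0.0, 1.0, 0.5])      # forget gate: drop, keep, keep half
i_t = np.array([1.0, 0.0, 0.5])      # input gate: admit, block, admit half
c_prev = np.array([2.0, 2.0, 2.0])   # previous cell state c_(t-1)
c_tilde = np.array([0.4, 0.4, 0.4])  # candidate state

# c_t = f_t * c_(t-1) + i_t * c~_t, with element-wise products
c_t = f_t * c_prev + i_t * c_tilde
print(c_t)  # values: 0.4, 2.0, 1.2
```

The first component forgets the old state and takes the candidate, the second keeps the old state untouched, and the third blends half of each.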

1.3.4 Output gate

The output gate formulas are:

$$o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh})$$

$$h_t = o_t * \tanh(c_t)$$

Interpretation:

As with the other gates, the output gate reads $h_{(t-1)}$ and $x_t$ and applies the sigmoid function to obtain its value. The cell state is then passed through tanh (giving a value between -1 and 1) and multiplied by the output gate result, which yields the final output $h_t$, that is, the new hidden state.

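Notes:

In the formulas above, $x_t$ is the current input, $h_{(t-1)}$ is the hidden state from the previous step, and $c_{(t-1)}$ is the cell state from the previous step.

When t = 1, $h_{(t-1)}$ is h0 and $c_{(t-1)}$ is c0.

In general, h0/c0 are set to 0 or 1, or to fixed random values.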
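
The four gate formulas can be put together in a short NumPy sketch of one LSTM time step, iterated over a sequence starting from h0 and c0. This is an illustration of the equations only, not MindSpore's implementation. The `params` dictionary layout, the random weights, and the row-vector convention (`x @ W` with W of shape input_size × hidden_size) are assumptions made for this example.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step following the gate formulas in 1.3.

    params is a hypothetical dict mapping each gate name ("f", "i", "c", "o")
    to a tuple (W_x, b_x, W_h, b_h).
    """
    def pre(gate):
        W_x, b_x, W_h, b_h = params[gate]
        return x_t @ W_x + b_x + h_prev @ W_h + b_h

    f_t = sigmoid(pre("f"))              # forget gate
    i_t = sigmoid(pre("i"))              # input gate
    c_tilde = np.tanh(pre("c"))          # candidate state
    o_t = sigmoid(pre("o"))              # output gate
    c_t = f_t * c_prev + i_t * c_tilde   # new cell state
    h_t = o_t * np.tanh(c_t)             # new hidden state
    return h_t, c_t


rng = np.random.default_rng(0)
batch_size, seq_len, input_size, hidden_size = 4, 8, 4, 8
params = {
    gate: (rng.normal(scale=0.1, size=(input_size, hidden_size)),
           np.zeros(hidden_size),
           rng.normal(scale=0.1, size=(hidden_size, hidden_size)),
           np.zeros(hidden_size))
    for gate in ("f", "i", "c", "o")
}

x = rng.random((batch_size, seq_len, input_size))
h = np.ones((batch_size, hidden_size))  # h0 set to ones
c = np.ones((batch_size, hidden_size))  # c0 set to ones
for t in range(seq_len):
    h, c = lstm_step(x[:, t, :], h, c, params)

print(h.shape, c.shape)
```

Because $h_t$ is a sigmoid output multiplied by a tanh output, every entry of `h` stays strictly between -1 and 1, while `c` is not bounded in the same way.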

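2. Documentation

Next, let's look at the official documentation, focusing on the parameters:
[Figure 3: parameter section of the official LSTM documentation]

According to the official documentation, the LSTM operator in MindSpore supports multi-layer and bidirectional configurations, accepts input whose first dimension is not batch_size, and has built-in dropout.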
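
The effect of these options on tensor shapes can be summarised in a small pure-Python sketch. `lstm_shapes` is a hypothetical helper written for this article, not part of MindSpore. It assumes that with `batch_first=False` the input and output are laid out as (seq_len, batch_size, ...), and that the state tensors always have shape (num_directions * num_layers, batch_size, hidden_size).

```python
def lstm_shapes(batch_size, seq_len, input_size, hidden_size,
                num_layers=1, bidirectional=False, batch_first=False):
    """Hypothetical helper: expected tensor shapes for one LSTM call."""
    num_bi = 2 if bidirectional else 1
    if batch_first:
        x_shape = (batch_size, seq_len, input_size)
        out_shape = (batch_size, seq_len, num_bi * hidden_size)
    else:
        x_shape = (seq_len, batch_size, input_size)
        out_shape = (seq_len, batch_size, num_bi * hidden_size)
    # h and c are stacked over layers and directions, never batch-first.
    state_shape = (num_bi * num_layers, batch_size, hidden_size)
    return {"x": x_shape, "output": out_shape, "h": state_shape, "c": state_shape}


# Single-layer unidirectional, batch-first
print(lstm_shapes(4, 8, 4, 8, batch_first=True))
# Two-layer bidirectional, batch-first
print(lstm_shapes(4, 8, 4, 8, num_layers=2, bidirectional=True, batch_first=True))
```

Only the output's last dimension depends on the number of directions, and only the first dimension of h and c depends on the number of layers.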

The examples below walk through the operator's inputs and outputs.

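3. Worked examples

3.1 Single-layer forward LSTM

This example randomly generates data of shape [4, 8, 4]: batch_size is 4, the fixed seq_length is 8, and the input dimension is 4.

It uses a single-layer unidirectional LSTM with a hidden size of 8.

The LSTM is called twice for comparison: once with seq_length left at its default value of None, and once with the valid lengths input_seq_length.

The example code is as follows: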

import numpy as np

from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM


def single_layer_lstm():
    random_data = np.random.rand(4, 8, 4)
    seq_length = [3, 8, 5, 1]
    input_seq_data = Tensor(random_data, dtype=dtype.float32)
    input_seq_length = Tensor(seq_length, dtype=dtype.int32)

    batch_size = 4
    input_size = 4
    hidden_size = 8
    num_layers = 1
    bidirectional = False
    num_bi = 2 if bidirectional else 1

    lstm = LSTM(
        input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
        has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)

    h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
    c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))

    output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
    output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)

    print("====== single layer lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
    print("====== single layer lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
    print("====== single layer lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)

    print("====== single layer lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
    print("====== single layer lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
    print("====== single layer lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)


if __name__ == "__main__":
    single_layer_lstm()

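The output of the example code is reproduced below, after the analysis.

Analysis of the output:

  1. output_0 and output_1 both have shape [4, 8, 8], i.e. batch_size, seq_length, and hidden_size.
  2. output_0 comes from the call with seq_length set to None, so the valid seq_length defaults to 8. None of the time steps in output_0 is all zeros.
  3. output_1 comes from the call with seq_length set to [3, 8, 5, 1]. The outputs beyond each sequence's valid length are all zeros.
  4. hn and cn are the hidden state and cell state outputs. hn_1 and cn_1 are used as examples below.
  5. hn_1 has shape [1, 4, 8]: 1 stands for unidirectional and single-layer (1*1), 4 is batch_size, and 8 is hidden_size.
  6. A closer look shows that, for each sequence, hn_1 matches the output_1 entry at the last time step within the valid length.
  7. cn_1 is the cell state at the last valid step.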
====== single layer lstm output 0 shape: (4, 8, 8) ======
[[[ 0.13193643  0.31574252  0.21773982  0.359429    0.23590101
    0.28213733  0.24443595  0.37388077]
  [-0.02988351  0.1415896   0.15356182  0.2834958  -0.00328176
    0.3491612   0.12643641  0.142024  ]
  [-0.09670443  0.03373189  0.1445203   0.19673887  0.06278481
    0.33509392 -0.02579015  0.07650157]
  [-0.15380219 -0.04781847  0.07795938  0.15893918  0.01305779
    0.33979264 -0.00364386  0.04361304]
  [-0.16254447 -0.06737433  0.05285644  0.10944269  0.01782622
    0.34567034 -0.04204851  0.01285298]
  [-0.21082401 -0.09526701  0.0265205   0.10617667 -0.03112434
    0.33731762 -0.02207689 -0.00955394]
  [-0.23450094 -0.09586379  0.02365175  0.09352495 -0.03744857
    0.33376914 -0.04699665 -0.03528202]
  [-0.24089803 -0.06166056  0.02839395  0.09916345 -0.04156012
    0.31369895 -0.08876226 -0.0487675 ]]

 [[ 0.10673305  0.30631748  0.22279048  0.35392687  0.270858
    0.2800686   0.21576329  0.37215734]
  [ 0.07373721  0.07924869  0.20754944  0.2059646   0.12672944
    0.35556036  0.05576535  0.2124105 ]
  [-0.09233213  0.02507205  0.11608997  0.23507075  0.0269099
    0.3196378   0.00475359  0.05898073]
  [-0.14939436 -0.04166775  0.07941992  0.15797664  0.02167228
    0.34059638 -0.02956495  0.00525782]
  [-0.18659307 -0.08790994  0.04543061  0.12085741  0.01649844
    0.33063915 -0.03531799 -0.01156766]
  [-0.22867033 -0.10603286  0.03872797  0.11688479  0.01904946
    0.3056394  -0.05695718 -0.01623933]
  [-0.21695574 -0.11095987  0.03115554  0.08672465  0.04249544
    0.3152427  -0.07418983 -0.02036544]
  [-0.21967101 -0.10076816  0.01712734  0.08198812  0.02862469
    0.31535396 -0.09173042 -0.05647325]]

 [[ 0.1493079   0.28768584  0.2575181   0.3199168   0.30599245
    0.28865623  0.16678075  0.41237575]
  [ 0.01445133  0.13631815  0.18265024  0.2577204   0.09361918
    0.3227448   0.04080902  0.17163058]
  [-0.1164555   0.05409181  0.1229048   0.24406306  0.02090637
    0.31171325 -0.02868806  0.06015658]
  [-0.12215493 -0.04073931  0.09229688  0.13461691  0.05322267
    0.34697118 -0.04028781  0.05017967]
  [-0.16058712 -0.02990636  0.06711683  0.13881728  0.04944531
    0.30471358 -0.08764775  0.01227296]
  [-0.17542893 -0.04518626  0.06441598  0.12666796  0.1039256
    0.29512212 -0.12625514 -0.01764686]
  [-0.18198647 -0.06205402  0.05437353  0.12312049  0.11571115
    0.27589387 -0.13898477 -0.00659172]
  [-0.18840623 -0.03089028  0.02871101  0.13332503  0.02779378
    0.2934873  -0.12758468 -0.02508291]]

 [[ 0.16055782  0.28248906  0.24979302  0.3381475   0.28849283
    0.3085897   0.21882199  0.3911534 ]
  [ 0.03212452  0.10363571  0.18571742  0.25555134  0.11808199
    0.33315352  0.0612903   0.16566488]
  [-0.09707587  0.08886775  0.130165    0.23324937  0.0596167
    0.28433815 -0.05993269  0.06611289]
  [-0.15705962 -0.00274712  0.09360209  0.18597823  0.04157853
    0.32279128 -0.07580574  0.01155218]
  [-0.15376413 -0.07929687  0.06302985  0.11465057  0.07184268
    0.3261627  -0.05871713  0.04223134]
  [-0.18791473 -0.07859336  0.02364462  0.12526496 -0.02513029
    0.33071572 -0.03542359 -0.00976665]
  [-0.23625109 -0.03007499  0.03267653  0.15940045 -0.08530897
    0.30445266 -0.0852924  -0.04507463]
  [-0.23499809 -0.07687293  0.03790941  0.08663946 -0.00264841
    0.33423126 -0.06512782  0.01413365]]]
====== single layer lstm hn0 shape: (1, 4, 8) ======
[[[-0.24089803 -0.06166056  0.02839395  0.09916345 -0.04156012
    0.31369895 -0.08876226 -0.0487675 ]
  [-0.21967101 -0.10076816  0.01712734  0.08198812  0.02862469
    0.31535396 -0.09173042 -0.05647325]
  [-0.18840623 -0.03089028  0.02871101  0.13332503  0.02779378
    0.2934873  -0.12758468 -0.02508291]
  [-0.23499809 -0.07687293  0.03790941  0.08663946 -0.00264841
    0.33423126 -0.06512782  0.01413365]]]
====== single layer lstm cn0 shape: (1, 4, 8) ======
[[[-0.72842515 -0.10623126  0.07748945  0.23840414 -0.0663506
    0.82394135 -0.20612013 -0.11983471]
  [-0.6431069  -0.17861958  0.04168103  0.20188545  0.0463764
    0.73273325 -0.21914008 -0.13169488]
  [-0.61163914 -0.05123866  0.07892742  0.32583922  0.04181815
    0.79872614 -0.2969701  -0.0625343 ]
  [-0.58037984 -0.15040846  0.09998614  0.24211554 -0.0044073
    0.8616534  -0.1546249   0.03137078]]]
====== single layer lstm output 1 shape: (4, 8, 8) ======
[[[ 0.13193643  0.31574252  0.21773985  0.35942894  0.23590101
    0.28213733  0.24443595  0.37388077]
  [-0.02988352  0.1415896   0.15356182  0.28349578 -0.00328175
    0.34916118  0.12643641  0.142024  ]
  [-0.09670443  0.0337319   0.14452031  0.19673884  0.06278481
    0.33509392 -0.02579015  0.07650157]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]]

 [[ 0.10673306  0.30631748  0.22279048  0.35392687  0.27085796
    0.2800686   0.21576326  0.37215734]
  [ 0.07373722  0.0792487   0.20754944  0.2059646   0.12672943
    0.35556036  0.05576536  0.2124105 ]
  [-0.09233214  0.02507207  0.11608997  0.23507075  0.02690989
    0.3196378   0.00475359  0.05898073]
  [-0.14939436 -0.04166774  0.07941992  0.15797664  0.02167228
    0.34059638 -0.02956495  0.00525782]
  [-0.18659307 -0.08790994  0.04543061  0.12085741  0.01649844
    0.33063915 -0.03531799 -0.01156766]
  [-0.22867033 -0.10603285  0.03872797  0.11688479  0.01904945
    0.3056394  -0.05695718 -0.01623933]
  [-0.21695574 -0.11095986  0.03115554  0.08672465  0.04249543
    0.3152427  -0.07418983 -0.02036544]
  [-0.21967097 -0.10076815  0.01712734  0.08198812  0.02862468
    0.31535396 -0.09173042 -0.05647324]]

 [[ 0.1493079   0.28768584  0.25751814  0.3199168   0.30599245
    0.28865623  0.16678077  0.41237575]
  [ 0.01445133  0.13631816  0.18265024  0.25772038  0.09361918
    0.3227448   0.04080902  0.17163058]
  [-0.1164555   0.05409183  0.1229048   0.24406303  0.02090637
    0.31171325 -0.02868806  0.06015658]
  [-0.12215493 -0.0407393   0.09229688  0.1346169   0.05322267
    0.3469712  -0.0402878   0.05017967]
  [-0.16058712 -0.02990635  0.06711683  0.13881728  0.0494453
    0.30471358 -0.08764775  0.01227296]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]]

 [[ 0.16055782  0.2824891   0.24979301  0.33814746  0.28849283
    0.30858967  0.21882202  0.3911534 ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]]]
====== single layer lstm hn1 shape: (1, 4, 8) ======
[[[-0.09670443  0.0337319   0.14452031  0.19673884  0.06278481
    0.33509392 -0.02579015  0.07650157]
  [-0.21967097 -0.10076815  0.01712734  0.08198812  0.02862468
    0.31535396 -0.09173042 -0.05647324]
  [-0.16058712 -0.02990635  0.06711683  0.13881728  0.0494453
    0.30471358 -0.08764775  0.01227296]
  [ 0.16055782  0.2824891   0.24979301  0.33814746  0.28849283
    0.30858967  0.21882202  0.3911534 ]]]
====== single layer lstm cn1 shape: (1, 4, 8) ======
[[[-0.22198828  0.05788375  0.38487202  0.5277796   0.10692163
    0.88817626 -0.06333658  0.15489307]
  [-0.6431068  -0.17861956  0.04168103  0.20188545  0.04637639
    0.73273325 -0.21914008 -0.13169487]
  [-0.44337854 -0.05043292  0.17615467  0.36942852  0.0769525
    0.8138213  -0.22219141  0.02737183]
  [ 0.50136805  0.47527558  0.8696786   0.7511291   0.37594885
    0.9162327   0.5345433   0.6333548 ]]]
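
Points 3 and 6 of the analysis in 3.1 can be mimicked in NumPy. The sketch below zeroes the time steps beyond each valid length and picks out the last valid step per sequence. `mask_and_last_valid` is a hypothetical helper and `full_output` is random stand-in data; this reproduces the behaviour observed in the printed results for the unidirectional case and is not MindSpore's internal implementation.

```python
import numpy as np


def mask_and_last_valid(full_output, seq_length):
    """full_output: (batch, seq_len, hidden); seq_length: valid length per sequence."""
    seq_length = np.asarray(seq_length)
    batch, seq_len, _ = full_output.shape
    steps = np.arange(seq_len)[None, :]            # (1, seq_len)
    valid = steps < seq_length[:, None]            # (batch, seq_len), True inside valid length
    masked = full_output * valid[:, :, None]       # padded steps become all zeros
    last = masked[np.arange(batch), seq_length - 1]  # (batch, hidden), last valid step
    return masked, last


rng = np.random.default_rng(0)
full_output = rng.random((4, 8, 8)) + 0.1   # stand-in for output_0, strictly non-zero
seq_length = [3, 8, 5, 1]

masked, last = mask_and_last_valid(full_output, seq_length)
print(masked.shape, last.shape)
```

Here `masked` plays the role of output_1 and `last` the role of hn_1[0]: for the first sequence (valid length 3), `last[0]` equals `full_output[0, 2]`, and steps 3 to 7 of `masked[0]` are zero.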

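3.2 Single-layer bidirectional LSTM

This example randomly generates data of shape [4, 8, 4]: batch_size is 4, the fixed seq_length is 8, and the input dimension is 4.

It uses a single-layer bidirectional LSTM with a hidden size of 8.

The LSTM is called twice for comparison: once with seq_length left at its default value of None, and once with the valid lengths input_seq_length.

The example code is as follows: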

import numpy as np

from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM


def single_layer_bi_lstm():
    random_data = np.random.rand(4, 8, 4)
    seq_length = [3, 8, 5, 1]
    input_seq_data = Tensor(random_data, dtype=dtype.float32)
    input_seq_length = Tensor(seq_length, dtype=dtype.int32)

    batch_size = 4
    input_size = 4
    hidden_size = 8
    num_layers = 1
    bidirectional = True
    num_bi = 2 if bidirectional else 1

    lstm = LSTM(
        input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
        has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)

    h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
    c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))

    output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
    output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)

    print("====== single layer bi lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
    print("====== single layer bi lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
    print("====== single layer bi lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)

    print("====== single layer bi lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
    print("====== single layer bi lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
    print("====== single layer bi lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)


if __name__ == "__main__":
    single_layer_bi_lstm()

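The output of the example code is reproduced below, after the analysis.

Analysis of the output:

  1. output_0 and output_1 both have shape [4, 8, 16], i.e. batch_size, seq_length, and hidden_size * 2. The factor of 2 comes from the bidirectional output.
  2. output_0 comes from the call with seq_length set to None, so the valid seq_length defaults to 8. None of the time steps in output_0 is all zeros.
  3. output_1 comes from the call with seq_length set to [3, 8, 5, 1]. The outputs beyond each sequence's valid length are all zeros.
  4. hn and cn are the hidden state and cell state outputs. hn_1 and cn_1 are used as examples below.
  5. hn_1 has shape [2, 4, 8]: 2 stands for bidirectional and single-layer (2*1), 4 is batch_size, and 8 is hidden_size.
  6. A closer look shows that the forward part, at index 0 of hn_1's first dimension, matches the first hidden_size values of output_1 at the last time step within the valid length.
  7. Likewise, the backward part, at index 1 of hn_1's first dimension, matches the last hidden_size values of output_1 at the first time step.
  8. cn_1 holds the cell states at those same steps: the last valid step for the forward direction and the first step for the backward direction.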
====== single layer bi lstm output 0 shape: (4, 8, 16) ======
[[[ 0.11591419  0.29961097  0.3425573   0.4287143   0.17212108
    0.07444338  0.43271446  0.15715674  0.08194006  0.11577142
   -0.09744498 -0.02763127  0.09280778  0.08716499  0.02522062
    0.33181873]
  [-0.01308823  0.13623668  0.19448121  0.37028143  0.22777143
    0.00628781  0.39128026  0.15501572  0.08111142  0.11017906
   -0.12316822 -0.00816909  0.09567513  0.05021677  0.08249568
    0.33742255]
  [-0.05627449  0.04682723  0.15380071  0.3137156   0.26430035
   -0.046514    0.35723254  0.16584632  0.10204285  0.10223756
   -0.13232729 -0.00190703  0.11279006  0.07007243  0.07809626
    0.36085904]
  [-0.09489179 -0.00705127  0.1340199   0.24711385  0.27097055
   -0.05539801  0.29088783  0.180727    0.13702057  0.07165765
   -0.15263684 -0.02301912  0.14440101  0.09643525  0.04434848
    0.32824463]
  [-0.13192342 -0.09842218  0.13483751  0.2363211   0.2714419
   -0.06301905  0.23002718  0.12190706  0.1600955   0.0820565
   -0.13324322  0.00847512  0.15308659  0.12757084  0.06873622
    0.3726861 ]
  [-0.16037701 -0.12437794  0.12642992  0.23676534  0.29797453
   -0.04277696  0.24219972  0.16359471  0.16195399  0.07269616
   -0.1250204  -0.0185749   0.19040069  0.12709007  0.12064856
    0.30454746]
  [-0.1353235  -0.12385159  0.1025193   0.23867385  0.30110353
   -0.03195428  0.2832907   0.18136714  0.19130123  0.09153596
   -0.05207976  0.02430173  0.2524703   0.22256352  0.17788586
    0.3196903 ]
  [-0.15227936 -0.16710246  0.11279354  0.2324703   0.3158889
   -0.05391366  0.28967926  0.21905534  0.34464788  0.06061291
    0.10662059  0.08228769  0.38103724  0.44488934  0.22631703
    0.38864976]]

 [[ 0.07946795  0.30921736  0.35205007  0.37194842  0.2058839
    0.09482588  0.4332572   0.2775039   0.10343523  0.07151344
   -0.13616626 -0.04245609  0.10985457  0.06919786  0.0364913
    0.31924048]
  [-0.04591701  0.14795585  0.20307627  0.35713255  0.21074952
    0.03478044  0.36047992  0.15351431  0.11235587  0.07168273
   -0.11715946 -0.02380875  0.11772131  0.11803672  0.00387634
    0.33266184]
  [-0.09412251  0.02499678  0.17255405  0.3178058   0.23692454
   -0.03471331  0.26576498  0.10732022  0.14581609  0.07355653
   -0.12852795  0.01927058  0.13053373  0.14796041  0.01590303
    0.3854578 ]
  [-0.09348419  0.00631614  0.1466178   0.22848201  0.22966608
   -0.05388562  0.14963126  0.08823045  0.15729474  0.0657778
   -0.15222837 -0.01835432  0.15758416  0.17561477 -0.03188463
    0.3511778 ]
  [-0.15382743 -0.04836275  0.14573918  0.22835778  0.2532363
   -0.03674607  0.1401736   0.09852327  0.17570393  0.04582136
   -0.13850203  0.00081276  0.16863164  0.14211492  0.04397457
    0.33833435]
  [-0.14028388 -0.08847751  0.13194019  0.21878807  0.28851762
   -0.06432837  0.15592363  0.16226491  0.20294866  0.04400881
   -0.11535563  0.04870296  0.22049154  0.17808373  0.09339966
    0.34441146]
  [-0.1683049  -0.16189072  0.1318028   0.22591397  0.3027075
   -0.07447627  0.15145044  0.1329806   0.2544369   0.06014252
   -0.01793557  0.11026148  0.2146467   0.3118566   0.12141219
    0.39812002]
  [-0.19805393 -0.17752953  0.12876241  0.21628919  0.3038769
   -0.036511    0.1357605   0.10460708  0.3527281   0.07156999
    0.1540587   0.09252883  0.35960466  0.54258245  0.16377062
    0.40849966]]

 [[ 0.08452003  0.3159105   0.3420099   0.3319746   0.20285761
    0.08632328  0.3581056   0.27760154  0.14828831  0.04973472
   -0.18127252 -0.02664946  0.11601479  0.06740937  0.0379785
    0.342705  ]
  [-0.0266434   0.16035607  0.18312001  0.31999707  0.22840345
    0.01311543  0.3133277   0.20360778  0.12191478  0.06214391
   -0.16598006 -0.03916245  0.10791545  0.06448431  0.03113508
    0.33138022]
  [-0.10794992  0.03787376  0.16952753  0.2500641   0.24685495
   -0.05109966  0.20483223  0.18794663  0.16794644  0.03811646
   -0.17785533  0.00866746  0.13491729  0.06493596  0.055873
    0.3487326 ]
  [-0.11205798 -0.04663825  0.13637729  0.2688466   0.2944545
   -0.06623676  0.24580626  0.1894824   0.12357055  0.08545923
   -0.13890322  0.02125055  0.12671538  0.05041068  0.10938939
    0.37651145]
  [-0.14464049 -0.11277611  0.12929943  0.2506328   0.32429394
   -0.06989705  0.26676533  0.22626272  0.14871088  0.06151669
   -0.14160013  0.01764496  0.15616798  0.06309532  0.11477884
    0.3533678 ]
  [-0.1919359  -0.14934857  0.12687694  0.2482472   0.30332044
   -0.02129422  0.24142255  0.19039477  0.1872613   0.05607529
   -0.10981983  0.02655923  0.19725962  0.15991098  0.08460074
    0.32532936]
  [-0.15997384 -0.16905244  0.12601317  0.24978957  0.3109707
   -0.05129525  0.25644392  0.18721735  0.23115595  0.07164647
   -0.04363466  0.09616573  0.23608637  0.23462081  0.16639999
    0.36137852]
  [-0.17784727 -0.19330868  0.12555353  0.25036657  0.3237954
   -0.05024423  0.27374345  0.16953917  0.3444527   0.074378
    0.12866443  0.11058272  0.34053382  0.47292238  0.20279881
    0.42136478]]

 [[ 0.09268619  0.35032618  0.34263822  0.33635783  0.19130397
    0.089779    0.3541034   0.26252666  0.15370639  0.05593391
   -0.16430146 -0.00316385  0.14068598  0.13546935 -0.01566708
    0.32892445]
  [ 0.00249528  0.16723414  0.19037648  0.32905748  0.20670214
   -0.01093364  0.22814633  0.10346357  0.14574584  0.08942283
   -0.13508694  0.02989143  0.13283192  0.155128   -0.00928066
    0.38435996]
  [-0.09191902  0.02066077  0.1762495   0.2693505   0.2615397
   -0.07361222  0.17539641  0.12341685  0.14845897  0.06833903
   -0.15054268  0.02503714  0.12414654  0.08736143  0.07049443
    0.35888508]
  [-0.08116069 -0.0288023   0.12298302  0.24174306  0.3107592
   -0.07053182  0.23929915  0.17529318  0.09909797  0.10476568
   -0.13906275 -0.0065798   0.12028767  0.09093229  0.08531829
    0.33838242]
  [-0.08996075 -0.04482763  0.10432535  0.18569301  0.29469466
   -0.064595    0.21119419  0.19096416  0.15567164  0.06260847
   -0.15861334 -0.01660161  0.17961282  0.14018227  0.05389842
    0.32480207]
  [-0.13079894 -0.12208281  0.11661161  0.20262218  0.31364897
   -0.09002802  0.23725566  0.21705934  0.20321131  0.03772969
   -0.12727125  0.04301733  0.21097985  0.16362298  0.12457186
    0.3570657 ]
  [-0.14077222 -0.14493458  0.10797977  0.20154148  0.32082993
   -0.06558356  0.24276899  0.20433648  0.23955566  0.04574178
   -0.03365875  0.05299059  0.26905897  0.3059458   0.11437013
    0.3523326 ]
  [-0.20353709 -0.20380074  0.12652008  0.19772139  0.28259847
   -0.04320877  0.1549557   0.12743628  0.37037018  0.04201189
    0.16136979  0.10812846  0.3535916   0.573114    0.14248823
    0.42301312]]]
====== single layer bi lstm hn0 shape: (2, 4, 8) ======
[[[-0.15227936 -0.16710246  0.11279354  0.2324703   0.3158889
   -0.05391366  0.28967926  0.21905534]
  [-0.19805393 -0.17752953  0.12876241  0.21628919  0.3038769
   -0.036511    0.1357605   0.10460708]
  [-0.17784727 -0.19330868  0.12555353  0.25036657  0.3237954
   -0.05024423  0.27374345  0.16953917]
  [-0.20353709 -0.20380074  0.12652008  0.19772139  0.28259847
   -0.04320877  0.1549557   0.12743628]]

 [[ 0.08194006  0.11577142 -0.09744498 -0.02763127  0.09280778
    0.08716499  0.02522062  0.33181873]
  [ 0.10343523  0.07151344 -0.13616626 -0.04245609  0.10985457
    0.06919786  0.0364913   0.31924048]
  [ 0.14828831  0.04973472 -0.18127252 -0.02664946  0.11601479
    0.06740937  0.0379785   0.342705  ]
  [ 0.15370639  0.05593391 -0.16430146 -0.00316385  0.14068598
    0.13546935 -0.01566708  0.32892445]]]
====== single layer bi lstm cn0 shape: (2, 4, 8) ======
[[[-0.48307976 -0.40690032  0.24048738  0.49366224  0.5961513
   -0.13565473  0.5191028   0.48418468]
  [-0.55306923 -0.41890883  0.31527558  0.4081013   0.5560535
   -0.10868378  0.22270739  0.224445  ]
  [-0.5595058  -0.5172409   0.28816614  0.4680259   0.6353333
   -0.1406159   0.45408633  0.39424264]
  [-0.55914015 -0.42366728  0.29431793  0.42468843  0.5133875
   -0.11134674  0.27713037  0.2564772 ]]

 [[ 0.13141792  0.26979685 -0.20174497 -0.06629345  0.16831748
    0.14618596  0.05280813  0.84774   ]
  [ 0.16957031  0.19068424 -0.28012666 -0.10653219  0.1932735
    0.12457087  0.07286038  0.91865647]
  [ 0.25553685  0.1275407  -0.37673476 -0.06495219  0.21608156
    0.11330918  0.07597075  0.97954106]
  [ 0.2739099   0.14198926 -0.342751   -0.00778307  0.25392675
    0.23573248 -0.03052862  0.89955646]]]
====== single layer bi lstm output 1 shape: (4, 8, 16) ======
[[[ 0.11591419  0.299611    0.3425573   0.4287143   0.17212108
    0.07444337  0.43271446  0.15715674  0.14267941  0.11772849
   -0.08396029 -0.0199183   0.17602898  0.19761203  0.06850712
    0.30409858]
  [-0.01308823  0.1362367   0.19448121  0.3702814   0.22777143
    0.00628781  0.39128026  0.1550157   0.19404428  0.11392959
   -0.04281732  0.02546077  0.24461909  0.24037687  0.16997418
    0.30728906]
  [-0.05627449  0.04682725  0.15380071  0.3137156   0.26430035
   -0.04651401  0.3572325   0.1658463   0.32523182  0.10201547
    0.12631407  0.07232428  0.37344953  0.46444228  0.22052252
    0.38782993]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]]

 [[ 0.07946795  0.30921736  0.35205007  0.37194842  0.2058839
    0.09482589  0.4332572   0.27750388  0.10343523  0.07151344
   -0.13616627 -0.04245608  0.10985459  0.06919786  0.0364913
    0.31924048]
  [-0.04591701  0.14795585  0.20307627  0.35713255  0.21074952
    0.03478044  0.36047992  0.1535143   0.11235587  0.07168273
   -0.11715946 -0.02380875  0.11772133  0.11803672  0.00387635
    0.33266184]
  [-0.09412251  0.02499679  0.17255405  0.31780577  0.23692457
   -0.03471331  0.265765    0.10732021  0.14581607  0.07355653
   -0.12852795  0.01927058  0.13053373  0.14796041  0.01590303
    0.38545772]
  [-0.09348419  0.00631614  0.14661779  0.228482    0.2296661
   -0.05388563  0.14963126  0.08823042  0.15729474  0.0657778
   -0.15222837 -0.01835432  0.15758416  0.17561477 -0.03188463
    0.35117778]
  [-0.15382743 -0.04836275  0.14573918  0.22835778  0.25323635
   -0.03674608  0.14017357  0.09852324  0.17570391  0.04582136
   -0.13850203  0.00081274  0.16863164  0.14211491  0.04397457
    0.33833435]
  [-0.14028388 -0.08847751  0.13194019  0.21878807  0.28851762
   -0.06432837  0.15592363  0.16226488  0.20294866  0.04400881
   -0.11535563  0.04870294  0.22049154  0.17808372  0.09339967
    0.34441146]
  [-0.1683049  -0.16189072  0.1318028   0.22591396  0.30270752
   -0.07447628  0.15145041  0.13298061  0.2544369   0.06014251
   -0.01793558  0.11026147  0.2146467   0.31185657  0.1214122
    0.39812005]
  [-0.19805394 -0.17752953  0.12876241  0.21628918  0.30387694
   -0.036511    0.1357605   0.10460708  0.3527281   0.07156998
    0.1540587   0.09252883  0.35960466  0.54258245  0.16377063
    0.40849966]]

 [[ 0.08452003  0.31591052  0.3420099   0.3319746   0.2028576
    0.08632328  0.3581056   0.2776015   0.16127887  0.05090985
   -0.18798977 -0.03278283  0.14869703  0.09618111  0.05077953
    0.32884052]
  [-0.0266434   0.16035606  0.18312001  0.31999707  0.22840345
    0.01311543  0.31332764  0.20360778  0.14828573  0.06162609
   -0.16532603 -0.04184524  0.17109753  0.11741111  0.05272176
    0.31123316]
  [-0.10794992  0.03787376  0.16952753  0.2500641   0.24685495
   -0.05109966  0.2048322   0.18794663  0.21637706  0.03754523
   -0.15342048  0.0159312   0.2186653   0.17495207  0.09126361
    0.32591543]
  [-0.11205798 -0.04663826  0.13637729  0.2688466   0.2944545
   -0.06623676  0.24580622  0.1894824   0.21777555  0.08560579
   -0.0555483   0.0522357   0.2504716   0.23061936  0.18061498
    0.34555358]
  [-0.14464049 -0.11277609  0.12929943  0.2506328   0.32429394
   -0.06989705  0.26676533  0.22626273  0.34267974  0.06394035
    0.10800922  0.07929072  0.38286424  0.44688055  0.22619261
    0.38621217]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]]

 [[ 0.09268619  0.35032618  0.34263822  0.33635783  0.19130397
    0.089779    0.3541034   0.26252666  0.34620598  0.06714007
    0.13512857  0.04233981  0.42014182  0.5216394   0.18838547
    0.3683127 ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]]]
====== single layer bi lstm hn1 shape: (2, 4, 8) ======
[[[-0.05627449  0.04682725  0.15380071  0.3137156   0.26430035
   -0.04651401  0.3572325   0.1658463 ]
  [-0.19805394 -0.17752953  0.12876241  0.21628918  0.30387694
   -0.036511    0.1357605   0.10460708]
  [-0.14464049 -0.11277609  0.12929943  0.2506328   0.32429394
   -0.06989705  0.26676533  0.22626273]
  [ 0.09268619  0.35032618  0.34263822  0.33635783  0.19130397
    0.089779    0.3541034   0.26252666]]

 [[ 0.14267941  0.11772849 -0.08396029 -0.0199183   0.17602898
    0.19761203  0.06850712  0.30409858]
  [ 0.10343523  0.07151344 -0.13616627 -0.04245608  0.10985459
    0.06919786  0.0364913   0.31924048]
  [ 0.16127887  0.05090985 -0.18798977 -0.03278283  0.14869703
    0.09618111  0.05077953  0.32884052]
  [ 0.34620598  0.06714007  0.13512857  0.04233981  0.42014182
    0.5216394   0.18838547  0.3683127 ]]]
====== single layer bi lstm cn1 shape: (2, 4, 8) ======
[[[-0.16340391  0.12338591  0.36321753  0.60983956  0.4963916
   -0.14528881  0.61422133  0.37583172]
  [-0.5530693  -0.41890883  0.31527558  0.40810126  0.5560536
   -0.10868377  0.22270739  0.22444502]
  [-0.46137562 -0.27004397  0.27595642  0.5348579   0.62363803
   -0.18086377  0.46610427  0.4973321 ]
  [ 0.23746979  0.6868869   0.56339467  0.96855223  0.39346337
    0.32335475  0.7259624   0.4185825 ]]

 [[ 0.22938183  0.2952913  -0.17549752 -0.05000385  0.33509728
    0.3336044   0.14473113  0.7370499 ]
  [ 0.16957031  0.19068426 -0.2801267  -0.10653219  0.19327351
    0.12457087  0.07286038  0.91865647]
  [ 0.27940926  0.13317151 -0.39137632 -0.081429    0.28198367
    0.16170114  0.10146889  0.91004795]
  [ 0.6180897   0.28882137  0.28748003  0.15160248  0.7991137
    0.90929043  0.45457762  0.8128108 ]]]
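
The relationship between hn and output in the bidirectional case (points 6 and 7 of the analysis in 3.2) can be reproduced with a toy model. To keep the sketch short it swaps the LSTM for a plain tanh RNN cell, so it shows only where each direction's final state lands in the concatenated output. `run_direction`, the random weights, and the single sequence of valid length 3 are all assumptions made for this illustration.

```python
import numpy as np


def run_direction(x, length, W, U, reverse=False):
    """Run a toy tanh RNN over the first `length` steps of x (seq_len, input_size).

    Returns the per-step outputs, zero-padded to seq_len, and the final state.
    """
    seq_len, hidden = x.shape[0], U.shape[0]
    out = np.zeros((seq_len, hidden))
    h = np.zeros(hidden)
    order = range(length - 1, -1, -1) if reverse else range(length)
    for t in order:
        h = np.tanh(x[t] @ W + h @ U)
        out[t] = h
    return out, h


rng = np.random.default_rng(0)
seq_len, input_size, hidden_size, length = 8, 4, 8, 3
x = rng.random((seq_len, input_size))
W_f = rng.normal(size=(input_size, hidden_size))
U_f = rng.normal(size=(hidden_size, hidden_size))
W_b = rng.normal(size=(input_size, hidden_size))
U_b = rng.normal(size=(hidden_size, hidden_size))

out_f, h_f = run_direction(x, length, W_f, U_f)                 # forward: steps 0, 1, 2
out_b, h_b = run_direction(x, length, W_b, U_b, reverse=True)   # backward: steps 2, 1, 0
output = np.concatenate([out_f, out_b], axis=-1)                # (seq_len, 2 * hidden_size)

# Forward final state sits in the first half of the last valid step.
assert np.allclose(h_f, output[length - 1, :hidden_size])
# Backward final state sits in the second half of the first step.
assert np.allclose(h_b, output[0, hidden_size:])
```

Because the backward pass starts at the last valid step rather than at step 7, its values inside the valid range differ from those of a full-length run. The printed results above show the same effect: the backward half of output_1 differs from output_0 for the sequences with valid length 3, 5, and 1.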

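3.3 Two-Layer Bidirectional LSTM

This example randomly generates data of shape [4, 8, 4]: batch_size is 4, the fixed seq_length is 8, and the input dimension is 4.

This example uses a two-layer bidirectional LSTM with a hidden size of 8.

For comparison, the LSTM is called twice: once with seq_length left at its default value None, and once with the valid lengths input_seq_length.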
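Before running the example, the expected result shapes can be worked out from the settings above. The following is a minimal sketch in plain Python; `expected_lstm_shapes` is a hypothetical helper written for illustration (it is not a MindSpore API), and it assumes batch_first=True as used in this example.

```python
def expected_lstm_shapes(batch_size, seq_length, hidden_size, num_layers, bidirectional):
    """Expected (output, state) shapes for a batch_first LSTM call.

    Hypothetical helper for illustration only; it is not part of MindSpore.
    """
    num_directions = 2 if bidirectional else 1
    # Only the last layer is returned, with both directions concatenated.
    output_shape = (batch_size, seq_length, hidden_size * num_directions)
    # One final state per layer and direction; hn and cn share this shape.
    state_shape = (num_layers * num_directions, batch_size, hidden_size)
    return output_shape, state_shape


output_shape, state_shape = expected_lstm_shapes(
    batch_size=4, seq_length=8, hidden_size=8, num_layers=2, bidirectional=True)
print(output_shape)  # (4, 8, 16)
print(state_shape)   # (4, 4, 8)
```

These values agree with the shape headers in the printed output of this section.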

The sample code is as follows:

import numpy as np

from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM


def double_layer_bi_lstm():
    # Random input of shape (batch_size=4, seq_length=8, input_size=4).
    random_data = np.random.rand(4, 8, 4)
    # Valid length of each sample in the batch.
    seq_length = [3, 8, 5, 1]
    input_seq_data = Tensor(random_data, dtype=dtype.float32)
    input_seq_length = Tensor(seq_length, dtype=dtype.int32)

    batch_size = 4
    input_size = 4
    hidden_size = 8
    num_layers = 2
    bidirectional = True
    num_bi = 2 if bidirectional else 1

    lstm = LSTM(
        input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
        has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)

    # Initial hidden and cell states: (num_bi * num_layers, batch_size, hidden_size).
    h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
    c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))

    # Call 0: seq_length defaults to None, so all 8 steps are treated as valid.
    output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
    # Call 1: pass the valid length of each sample.
    output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)

    print("====== double layer bi lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
    print("====== double layer bi lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
    print("====== double layer bi lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)

    print("====== double layer bi lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
    print("====== double layer bi lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
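    print("====== double layer bi lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)


if __name__ == "__main__":
    # Run the example when the file is executed as a script.
    double_layer_bi_lstm()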

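The output of the sample code is listed in full after the analysis below.

Analysis of the output:

  1. output_0 and output_1 both have shape [4, 8, 16], i.e. batch_size, seq_length and hidden_size * 2; the factor of 2 comes from the bidirectional output.
  2. Both output_0 and output_1 are the output of the second (last) layer; the output of the intermediate layer (the first layer in this example) is not returned explicitly.
  3. output_0 corresponds to the call with seq_length set to None, so the valid seq_length defaults to 8; none of its time steps is all zeros.
  4. output_1 corresponds to the call with seq_length set to [3, 8, 5, 1]; every time step beyond a sample's valid length is all zeros.
  5. hn and cn are the hidden state and cell state outputs, respectively. The points below use hn_1 and cn_1 as examples.
  6. hn_1 has shape [4, 4, 8]: the first 4 is two layers times two directions (2*2), the second 4 is batch_size, and 8 is hidden_size.
  7. As noted in item 6, the first dimension of 4 covers two layers in two directions: hn_1 holds the final valid hidden state of every layer, unlike output_1, which contains only the output of the last layer.
  8. On close inspection, index 2 of the first dimension of hn_1 (the forward direction of the last layer) equals the first hidden_size values of output_1 at the last valid time step of each sample.
  9. Likewise, index 3 of the first dimension of hn_1 (the backward direction of the last layer) equals the last hidden_size values of output_1 at the first time step of each sample.
  10. cn_1 holds the cell state at the last valid step, for each layer and direction.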
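Items 8 and 9 can be checked numerically. The following is a minimal sketch, not part of the original example: `last_layer_states_match` is a hypothetical helper, and it assumes a batch_first result and that the last two entries of hn are the forward and backward directions of the last layer, as described above. The sample values are copied from the fourth sample (valid length 1) in the printed output_1 and hn_1; the first-layer entries of hn are left as zeros because they are not checked.

```python
def last_layer_states_match(output, hn, seq_length, hidden_size, atol=1e-6):
    """Check items 8 and 9 for a batch_first, bidirectional LSTM result.

    Hypothetical helper (not a MindSpore API). Assumes hn[-2] and hn[-1] are
    the forward and backward final hidden states of the last layer.
    Works on nested lists and on NumPy arrays, since it only uses indexing.
    """
    def close(left, right):
        return all(abs(a - b) <= atol for a, b in zip(left, right))

    for b, valid_len in enumerate(seq_length):
        # Forward direction: first hidden_size values at the last valid step.
        forward = output[b][valid_len - 1][:hidden_size]
        # Backward direction: last hidden_size values at the first step.
        backward = output[b][0][hidden_size:]
        if not close(hn[-2][b], forward) or not close(hn[-1][b], backward):
            return False
    return True


# Values copied from the printed results for the fourth sample (valid length 1).
first_step = [
    3.3036014e-01, 2.2069807e-01, 4.0932164e-01, 5.0686938e-01,
    2.5304586e-01, 4.5349576e-02, 1.6947377e-01, 2.6356062e-01,
    6.4686131e-01, 1.8447271e-01, 2.6571944e-01, 3.6628011e-01,
    2.0576611e-01, 5.9034787e-02, 1.3657802e-01, 1.4004102e-01,
]
# One sample, 8 steps, 16 features; steps beyond the valid length stay zero.
sample_output = [[first_step] + [[0.0] * 16 for _ in range(7)]]

# Shape (4, 1, 8): two layers times two directions, one sample, hidden_size 8.
sample_hn = [
    [[0.0] * 8],  # layer 1, forward (not checked)
    [[0.0] * 8],  # layer 1, backward (not checked)
    [[0.33036014, 0.22069807, 0.40932164, 0.5068694,
      0.25304586, 0.04534958, 0.16947377, 0.26356062]],  # layer 2, forward
    [[0.6468613, 0.18447271, 0.26571944, 0.3662801,
      0.20576611, 0.05903479, 0.13657802, 0.14004102]],  # layer 2, backward
]

print(last_layer_states_match(sample_output, sample_hn, [1], hidden_size=8))  # True
```

With a real run, the same check could be applied to the full batch by passing `output_1.asnumpy()`, `hn_1.asnumpy()` and the list `[3, 8, 5, 1]`.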
====== double layer bi lstm output 0 shape: (4, 8, 16) ======
[[[ 3.70550364e-01  2.17652053e-01  3.79816592e-01  5.39002419e-01
    2.28588611e-01  3.83301824e-02  2.20795229e-01  2.44438455e-01
    2.06572518e-01 -3.78293954e-02  2.60271341e-01 -4.60247397e-02
   -3.78369205e-02 -1.90976545e-01 -1.01466656e-01  1.76680252e-01]
  [ 1.65173441e-01  7.22418576e-02  4.98769164e-01  2.52682149e-01
    2.94478923e-01 -1.56086944e-02  1.32235214e-01  4.96024750e-02
    1.81777030e-01 -7.20555857e-02  2.31085896e-01  7.43698841e-03
   -2.21280195e-02 -1.63902551e-01 -8.19268897e-02  1.90522313e-01]
  [ 3.94219235e-02 -8.84856097e-03  4.88511086e-01  1.51095495e-01
    2.83691764e-01 -2.36562286e-02  1.14125453e-01 -4.99135666e-02
    1.84900641e-01 -9.07974318e-02  2.06634849e-01  5.43768853e-02
   -2.88773868e-02 -1.41080543e-01 -7.59911761e-02  1.93940982e-01]
  [-3.12361736e-02 -5.87114133e-02  4.53683615e-01  7.93214589e-02
    2.92402357e-01 -2.14897078e-02  1.08925141e-01 -9.88882780e-02
    1.98123455e-01 -8.50049481e-02  1.91045731e-01  8.83036405e-02
   -1.40397642e-02 -1.22237459e-01 -6.35140762e-02  1.80813670e-01]
  [-6.87201098e-02 -6.12376854e-02  4.39131975e-01  2.83084475e-02
    2.86313444e-01 -9.33245104e-03  1.12482831e-01 -1.27253398e-01
    2.32264340e-01 -7.87357539e-02  1.86317161e-01  1.59440145e-01
    1.36264751e-03 -9.95954126e-02 -4.97992262e-02  1.69756234e-01]
  [-8.29227120e-02 -6.19332492e-02  4.27550107e-01 -1.70003679e-02
    2.88041800e-01  3.62846977e-03  1.04239471e-01 -1.43706441e-01
    2.90384740e-01 -5.84731065e-02  1.86135545e-01  2.26804867e-01
    3.95135172e-02 -6.33978993e-02 -1.63939036e-02  1.48533911e-01]
  [-8.72429982e-02 -6.10240139e-02  4.20974702e-01 -6.44157380e-02
    2.92603880e-01  2.60243341e-02  9.26012769e-02 -1.46479979e-01
    3.93343538e-01 -1.41044548e-02  1.96197629e-01  3.05834383e-01
    1.02294169e-01  2.09005456e-03  5.07600456e-02  1.33950055e-01]
  [-8.48584175e-02 -4.15292941e-02  4.26153004e-01 -1.12198450e-01
    2.93441713e-01  4.73045520e-02  7.22456872e-02 -1.52661309e-01
    6.08003795e-01  1.02589525e-01  2.28410736e-01  3.57809156e-01
    2.30974391e-01  7.29562640e-02  1.54908523e-01  1.37615114e-01]]

 [[ 3.73128176e-01  2.24487275e-01  3.83654892e-01  5.39644539e-01
    2.24863932e-01  3.69703583e-02  2.22563371e-01  2.47377262e-01
    2.09958509e-01 -3.67934220e-02  2.55294740e-01 -5.44558465e-02
   -3.49954516e-02 -1.88630879e-01 -9.97974724e-02  1.72440261e-01]
  [ 1.63444579e-01  7.47621208e-02  4.95126337e-01  2.49838263e-01
    2.98441172e-01 -2.29644943e-02  1.30464450e-01  4.65075821e-02
    1.87749639e-01 -5.69685884e-02  2.30926782e-01 -1.89751368e-02
   -6.57672016e-03 -1.64301425e-01 -7.78417960e-02  1.70920238e-01]
  [ 2.62361914e-02 -2.09027641e-02  4.81326580e-01  1.54101923e-01
    2.95957267e-01 -3.76441851e-02  1.13665104e-01 -5.53984046e-02
    1.96336910e-01 -6.99553713e-02  2.13279501e-01  2.09173746e-02
   -6.90750126e-03 -1.42273992e-01 -7.39771128e-02  1.70230061e-01]
  [-4.55044061e-02 -8.81957486e-02  4.49505210e-01  8.37849677e-02
    3.12549353e-01 -3.09768375e-02  9.69471037e-02 -9.93652195e-02
    2.07049429e-01 -6.65001795e-02  1.99929893e-01  4.60516922e-02
    9.15598311e-03 -1.23334207e-01 -6.36003762e-02  1.58215716e-01]
  [-8.54137391e-02 -9.59964097e-02  4.28828478e-01  2.81018596e-02
    3.12747598e-01 -1.96594596e-02  1.04248613e-01 -1.21685371e-01
    2.44353175e-01 -6.95914254e-02  1.87495902e-01  1.23339958e-01
    1.20015517e-02 -9.19487774e-02 -5.30561097e-02  1.52850106e-01]
  [-1.03212453e-01 -9.74086747e-02  4.10266966e-01 -2.03387272e-02
    3.20133060e-01  5.47134259e-04  1.07527576e-01 -1.26215830e-01
    3.05427969e-01 -5.22202961e-02  1.89031556e-01  2.16380343e-01
    4.27359492e-02 -5.31105101e-02 -2.50125714e-02  1.44858196e-01]
  [-1.06917843e-01 -7.13929683e-02  4.10624832e-01 -5.51486127e-02
    3.07110429e-01  1.92907490e-02  1.03878655e-01 -1.38662428e-01
    4.00884181e-01 -1.43600125e-02  1.82524621e-01  2.97586024e-01
    9.52146128e-02  9.59962141e-03  5.30949272e-02  1.37635604e-01]
  [-9.71160829e-02 -4.43801992e-02  4.20233607e-01 -1.02356419e-01
    3.03063601e-01  3.99401113e-02  8.28935355e-02 -1.43912748e-01
    6.09543681e-01  1.04935512e-01  2.27933496e-01  3.57850134e-01
    2.31336534e-01  7.57181123e-02  1.55172557e-01  1.39436752e-01]]

 [[ 3.74232024e-01  2.23312378e-01  3.80826175e-01  5.25748074e-01
    2.30494052e-01  3.75359394e-02  2.19325155e-01  2.45338157e-01
    1.90327644e-01 -9.49237868e-03  2.51282185e-01 -4.07305919e-02
   -7.68693071e-03 -1.96041882e-01 -9.43402052e-02  1.52500823e-01]
  [ 1.65756628e-01  8.52986127e-02  5.00474215e-01  2.32285380e-01
    2.97197372e-01 -2.87767611e-02  1.31484732e-01  4.05624248e-02
    1.72598451e-01 -3.74435596e-02  2.30013907e-01  1.03627918e-02
    1.63554456e-02 -1.71838194e-01 -7.55213797e-02  1.56671956e-01]
  [ 3.51614878e-02 -3.49920541e-02  4.85133171e-01  1.37813956e-01
    3.03884476e-01 -3.76141518e-02  9.96868908e-02 -4.97255772e-02
    1.81163609e-01 -4.24254723e-02  2.27177203e-01  3.23883444e-02
    2.71688756e-02 -1.56165496e-01 -6.69138283e-02  1.53632939e-01]
  [-4.10026051e-02 -8.96424949e-02  4.60784853e-01  8.30888674e-02
    3.03816915e-01 -2.20339652e-02  9.38846841e-02 -9.45615992e-02
    2.04564795e-01 -4.51925248e-02  2.18029544e-01  6.01283386e-02
    3.36706154e-02 -1.35854393e-01 -5.57745472e-02  1.48557410e-01]
  [-8.25456828e-02 -1.13149934e-01  4.36939508e-01  3.75392586e-02
    3.10225427e-01 -7.73321884e-03  9.12441462e-02 -1.16306305e-01
    2.42686659e-01 -4.25874330e-02  2.11468235e-01  1.09053820e-01
    4.69379947e-02 -1.04551256e-01 -4.02252935e-02  1.34793952e-01]
  [-1.08169496e-01 -1.15720116e-01  4.16452408e-01  4.10868321e-03
    3.16107094e-01  6.06524665e-03  9.51950625e-02 -1.27826288e-01
    3.06058168e-01 -3.21962573e-02  2.01961204e-01  1.87839821e-01
    6.73103184e-02 -5.98271154e-02 -1.05028180e-02  1.28264755e-01]
  [-1.16449505e-01 -1.07103497e-01  4.10319597e-01 -3.42636257e-02
    3.23818535e-01  2.40915213e-02  9.08538699e-02 -1.28739789e-01
    4.00041372e-01  5.13588311e-03  2.06977740e-01  2.77402431e-01
    1.18934669e-01  6.60364656e-03  5.48240133e-02  1.22762337e-01]
  [-1.07369550e-01 -7.64680207e-02  4.24612671e-01 -8.88631567e-02
    3.25147092e-01  5.22605665e-02  7.02133700e-02 -1.30118832e-01
    6.03053808e-01  1.08490229e-01  2.35621274e-01  3.42306137e-01
    2.33348757e-01  7.23976195e-02  1.51835442e-01  1.38724014e-01]]

 [[ 3.68833274e-01  2.19720796e-01  3.75712991e-01  5.39344609e-01
    2.32777387e-01  3.75517495e-02  2.15990663e-01  2.38119900e-01
    2.03846872e-01 -3.31601547e-03  2.63746709e-01 -5.33154309e-02
   -1.53900171e-02 -1.96350247e-01 -9.86721516e-02  1.51238605e-01]
  [ 1.61587596e-01  7.25713074e-02  4.97545034e-01  2.48409301e-01
    3.00032824e-01 -2.52650958e-02  1.25469610e-01  4.12617065e-02
    1.75564945e-01 -3.84877101e-02  2.34954998e-01  1.90881861e-03
    7.01279286e-03 -1.72224715e-01 -7.77121335e-02  1.60935923e-01]
  [ 2.84800380e-02 -2.69929953e-02  4.86053288e-01  1.57494590e-01
    2.96494991e-01 -3.40557620e-02  1.04029477e-01 -5.39027080e-02
    1.82317436e-01 -5.37234657e-02  2.23423839e-01  4.04849648e-02
    8.95922631e-03 -1.53901607e-01 -7.44922534e-02  1.65948585e-01]
  [-3.72786410e-02 -7.53442869e-02  4.61774200e-01  8.63353312e-02
    2.97733396e-01 -2.75274049e-02  9.13189948e-02 -1.00060880e-01
    1.94108337e-01 -5.79617955e-02  2.08687440e-01  6.31403774e-02
    2.11703759e-02 -1.34831637e-01 -6.31042644e-02  1.52588978e-01]
  [-7.60064349e-02 -1.06220305e-01  4.34687048e-01  3.19332667e-02
    3.09678972e-01 -1.16188908e-02  8.85540992e-02 -1.18266501e-01
    2.29653955e-01 -5.94241545e-02  2.00053185e-01  1.14932276e-01
    3.13343108e-02 -1.04001120e-01 -4.90994565e-02  1.44359529e-01]
  [-9.52797905e-02 -9.27509218e-02  4.22483116e-01 -1.29148299e-02
    3.04568678e-01  9.32686683e-03  9.81104076e-02 -1.28704712e-01
    2.98035592e-01 -5.08954525e-02  1.98656082e-01  2.12906018e-01
    5.04655764e-02 -6.18565194e-02 -2.38872226e-02  1.40028179e-01]
  [-9.81744751e-02 -8.54582712e-02  4.15283144e-01 -6.42896220e-02
    3.11841279e-01  3.18106599e-02  8.80582407e-02 -1.32987425e-01
    3.88665676e-01 -1.39519377e-02  1.92815915e-01  2.86827296e-01
    1.07908145e-01  2.11709971e-03  4.85477857e-02  1.27813160e-01]
  [-9.11041871e-02 -4.77942340e-02  4.29545075e-01 -1.14117011e-01
    3.04611683e-01  5.14086746e-02  7.33837485e-02 -1.44734517e-01
    6.06585741e-01  9.89784896e-02  2.24559098e-01  3.55441421e-01
    2.28052005e-01  7.30600879e-02  1.55306384e-01  1.37683451e-01]]]
====== double layer bi lstm hn0 shape: (4, 4, 8) ======
[[[ 0.25934413 -0.07461581  0.19370164  0.11095355  0.02041678
    0.29797387  0.03047622  0.19640712]
  [ 0.2874061  -0.08844143  0.22119689  0.1251989  -0.01900517
    0.29294112  0.05027778  0.2071664 ]
  [ 0.2596095   0.03271259  0.26155     0.10348854  0.08536521
    0.28197888 -0.08929807  0.18018515]
  [ 0.2509837  -0.07010224  0.20813467  0.10349585  0.04007874
    0.27277622  0.01278557  0.18474495]]

 [[-0.00949934  0.10407767  0.038502    0.14573903 -0.14825179
   -0.08745017  0.3038079   0.28010136]
  [ 0.05813041  0.14894389  0.05397653  0.15691832 -0.16107248
   -0.06869183  0.27977887  0.26698047]
  [-0.05296279  0.02392143  0.06922498  0.16198513 -0.12499766
   -0.063968    0.2682934   0.25862688]
  [-0.03301367  0.04014921 -0.00048225  0.1180163  -0.12858163
   -0.07102007  0.35664883  0.26105112]]

 [[-0.08485842 -0.04152929  0.426153   -0.11219845  0.2934417
    0.04730455  0.07224569 -0.15266131]
  [-0.09711608 -0.0443802   0.4202336  -0.10235642  0.3030636
    0.03994011  0.08289354 -0.14391275]
  [-0.10736955 -0.07646802  0.42461267 -0.08886316  0.3251471
    0.05226057  0.07021337 -0.13011883]
  [-0.09110419 -0.04779423  0.42954507 -0.11411701  0.30461168
    0.05140867  0.07338375 -0.14473452]]

 [[ 0.20657252 -0.0378294   0.26027134 -0.04602474 -0.03783692
   -0.19097655 -0.10146666  0.17668025]
  [ 0.20995851 -0.03679342  0.25529474 -0.05445585 -0.03499545
   -0.18863088 -0.09979747  0.17244026]
  [ 0.19032764 -0.00949238  0.2512822  -0.04073059 -0.00768693
   -0.19604188 -0.09434021  0.15250082]
  [ 0.20384687 -0.00331602  0.2637467  -0.05331543 -0.01539002
   -0.19635025 -0.09867215  0.1512386 ]]]
====== double layer bi lstm cn0 shape: (4, 4, 8) ======
[[[ 0.5770398  -0.16899881  0.40028483  0.25001454  0.04046626
    0.57915956  0.05266067  0.52447474]
  [ 0.66343445 -0.19959925  0.49729916  0.27566156 -0.03596141
    0.5509572   0.0853648   0.5394346 ]
  [ 0.5707181   0.07038814  0.5712474   0.2565448   0.1530705
    0.57276523 -0.15605333  0.46282846]
  [ 0.55990976 -0.16366895  0.4313923   0.23668876  0.08243398
    0.53433377  0.02196771  0.4817235 ]]

 [[-0.02554817  0.2071405   0.07978731  0.2778875  -0.24753608
   -0.2485388   0.62492937  0.6474521 ]
  [ 0.16052538  0.31375027  0.1059354   0.2853353  -0.26115927
   -0.20904504  0.5899866   0.56931025]
  [-0.14657407  0.05189808  0.13706218  0.33399543 -0.2142592
   -0.16363172  0.612855    0.61697096]
  [-0.0884767   0.07950284 -0.00107491  0.2254872  -0.21063672
   -0.20023198  0.72448045  0.60711044]]

 [[-0.2504415  -0.0814982   0.7923428  -0.19285998  0.5903069
    0.13990048  0.15511556 -0.2908177 ]
  [-0.28950468 -0.08669281  0.7886544  -0.17458251  0.6081315
    0.12001925  0.17698732 -0.2759574 ]
  [-0.30495524 -0.14845964  0.79688644 -0.15463473  0.6548568
    0.15446547  0.1526669  -0.24459954]
  [-0.265516   -0.09397535  0.79843074 -0.19696996  0.6198776
    0.15148453  0.15768716 -0.275381  ]]

 [[ 0.32853472 -0.05710489  0.7447654  -0.0758819  -0.09938034
   -0.47783113 -0.28168824  0.36019933]
  [ 0.33408064 -0.05591211  0.7391405  -0.08961775 -0.0917803
   -0.47115833 -0.278066    0.35383248]
  [ 0.30187273 -0.01431822  0.7146605  -0.06792408 -0.02012375
   -0.48834586 -0.26035625  0.3151392 ]
  [ 0.32118577 -0.00497683  0.7502155  -0.08775105 -0.04013083
   -0.4903597  -0.27541417  0.30617815]]]
====== double layer bi lstm output 1 shape: (4, 8, 16) ======
[[[ 3.5416836e-01  2.0936093e-01  3.8317284e-01  5.3357160e-01
    2.4053907e-01  4.1459590e-02  2.0509864e-01  2.5311515e-01
    3.7313861e-01  2.2726113e-02  2.4815443e-01  1.6349553e-01
    1.1913014e-02 -1.0416587e-01 -4.6682160e-02  1.2466244e-01]
  [ 1.6695338e-01  8.1573747e-02  5.0642765e-01  2.2585270e-01
    3.1199178e-01  7.0200888e-03  1.0298288e-01  7.1754217e-02
    4.2964008e-01  2.7423983e-02  2.2389892e-01  2.8188041e-01
    9.3678713e-02 -1.6824452e-02  4.4604652e-02  1.2561245e-01]
  [ 6.0777575e-02  3.0208385e-02  5.1636058e-01  8.0109224e-02
    3.0168548e-01  1.5010678e-02  5.8312915e-02 -2.7518146e-02
    6.2040079e-01  1.1676422e-01  2.4167898e-01  3.6679846e-01
    2.2570200e-01  6.9053181e-02  1.5332413e-01  1.3909420e-01]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]]

 [[ 3.7312818e-01  2.2448727e-01  3.8365489e-01  5.3964454e-01
    2.2486390e-01  3.6970358e-02  2.2256340e-01  2.4737728e-01
    2.0995849e-01 -3.6793407e-02  2.5529474e-01 -5.4455854e-02
   -3.4995444e-02 -1.8863088e-01 -9.9797480e-02  1.7244026e-01]
  [ 1.6344458e-01  7.4762136e-02  4.9512634e-01  2.4983825e-01
    2.9844120e-01 -2.2964491e-02  1.3046446e-01  4.6507578e-02
    1.8774964e-01 -5.6968573e-02  2.3092678e-01 -1.8975141e-02
   -6.5767197e-03 -1.6430146e-01 -7.7841796e-02  1.7092024e-01]
  [ 2.6236186e-02 -2.0902762e-02  4.8132658e-01  1.5410189e-01
    2.9595733e-01 -3.7644185e-02  1.1366512e-01 -5.5398405e-02
    1.9633688e-01 -6.9955371e-02  2.1327947e-01  2.0917373e-02
   -6.9075003e-03 -1.4227399e-01 -7.3977120e-02  1.7023006e-01]
  [-4.5504406e-02 -8.8195749e-02  4.4950521e-01  8.3784960e-02
    3.1254938e-01 -3.0976830e-02  9.6947111e-02 -9.9365219e-02
    2.0704943e-01 -6.6500187e-02  1.9992988e-01  4.6051688e-02
    9.1559850e-03 -1.2333421e-01 -6.3600369e-02  1.5821570e-01]
  [-8.5413747e-02 -9.5996402e-02  4.2882851e-01  2.8101865e-02
    3.1274763e-01 -1.9659458e-02  1.0424862e-01 -1.2168537e-01
    2.4435315e-01 -6.9591425e-02  1.8749590e-01  1.2333996e-01
    1.2001552e-02 -9.1948770e-02 -5.3056102e-02  1.5285012e-01]
  [-1.0321245e-01 -9.7408667e-02  4.1026697e-01 -2.0338718e-02
    3.2013306e-01  5.4713513e-04  1.0752757e-01 -1.2621583e-01
    3.0542794e-01 -5.2220318e-02  1.8903156e-01  2.1638034e-01
    4.2735931e-02 -5.3110521e-02 -2.5012573e-02  1.4485820e-01]
  [-1.0691784e-01 -7.1392961e-02  4.1062483e-01 -5.5148609e-02
    3.0711043e-01  1.9290760e-02  1.0387863e-01 -1.3866244e-01
    4.0088418e-01 -1.4360026e-02  1.8252462e-01  2.9758602e-01
    9.5214583e-02  9.5995963e-03  5.3094927e-02  1.3763560e-01]
  [-9.7116083e-02 -4.4380195e-02  4.2023361e-01 -1.0235640e-01
    3.0306363e-01  3.9940134e-02  8.2893521e-02 -1.4391276e-01
    6.0954368e-01  1.0493548e-01  2.2793353e-01  3.5785013e-01
    2.3133652e-01  7.5718097e-02  1.5517256e-01  1.3943677e-01]]

 [[ 3.6901441e-01  2.1822800e-01  3.7994039e-01  5.2547783e-01
    2.3396042e-01  3.9366722e-02  2.1538821e-01  2.4702020e-01
    2.4914475e-01 -6.9778422e-03  2.4806115e-01  2.1838229e-02
   -1.3991867e-02 -1.6620368e-01 -8.7110944e-02  1.4123847e-01]
  [ 1.6616049e-01  8.4187903e-02  4.9948204e-01  2.2646046e-01
    3.0369779e-01 -1.7643329e-02  1.2668489e-01  4.9117617e-02
    2.6261702e-01 -2.7619595e-02  2.2540939e-01  1.1914852e-01
    2.3004401e-02 -1.2194993e-01 -5.5561494e-02  1.3998528e-01]
  [ 4.2908981e-02 -2.5578242e-02  4.8486653e-01  1.1890158e-01
    3.1149039e-01 -1.4618633e-02  9.1249026e-02 -3.3213440e-02
    3.1701097e-01 -1.8276740e-02  2.2031868e-01  2.0087981e-01
    5.8553118e-02 -7.3650509e-02 -1.7827954e-02  1.3095699e-01]
  [-2.2401063e-02 -6.7246288e-02  4.6379456e-01  4.6429519e-02
    3.1024706e-01  1.2560772e-02  7.6885723e-02 -7.1739145e-02
    4.0658230e-01  1.3608186e-02  2.1248461e-01  2.7639762e-01
    1.0969905e-01 -1.7181308e-03  5.7507429e-02  1.2614906e-01]
  [-4.9086079e-02 -6.1570432e-02  4.6209678e-01 -3.5342608e-02
    3.1426692e-01  4.2432975e-02  5.4815758e-02 -9.5721334e-02
    6.0554379e-01  1.1493160e-01  2.4293001e-01  3.4404746e-01
    2.3283333e-01  6.8980336e-02  1.5239350e-01  1.3767722e-01]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]]

 [[ 3.3036014e-01  2.2069807e-01  4.0932164e-01  5.0686938e-01
    2.5304586e-01  4.5349576e-02  1.6947377e-01  2.6356062e-01
    6.4686131e-01  1.8447271e-01  2.6571944e-01  3.6628011e-01
    2.0576611e-01  5.9034787e-02  1.3657802e-01  1.4004102e-01]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]]]
====== double layer bi lstm hn1 shape: (4, 4, 8) ======
[[[ 0.30786592 -0.05702875  0.2098356   0.1831936   0.1446731
    0.35495615  0.10906219  0.2584008 ]
  [ 0.28740606 -0.08844142  0.2211969   0.12519889 -0.01900517
    0.29294112  0.05027781  0.2071664 ]
  [ 0.25389883  0.05431987  0.24731106  0.1163514   0.12489295
    0.31806058 -0.07178076  0.20686159]
  [ 0.47720045  0.11175225  0.22376464  0.36412558  0.46750376
    0.28765967  0.38535532  0.33306697]]

 [[ 0.0012262   0.3199089  -0.02733669  0.17044675 -0.04726706
   -0.02164171  0.28464028  0.3348536 ]
  [ 0.05813042  0.14894389  0.05397653  0.15691833 -0.16107246
   -0.06869183  0.27977887  0.26698047]
  [-0.04329334  0.12033389  0.03753637  0.15189895 -0.11344916
   -0.04964198  0.27086687  0.28215134]
  [ 0.05921583  0.543903    0.00194274  0.27610534  0.16461822
    0.25555757  0.18277422  0.3662175 ]]

 [[ 0.06077757  0.03020838  0.5163606   0.08010922  0.30168548
    0.01501068  0.05831292 -0.02751815]
  [-0.09711608 -0.0443802   0.4202336  -0.1023564   0.30306363
    0.03994013  0.08289352 -0.14391276]
  [-0.04908608 -0.06157043  0.46209678 -0.03534261  0.31426692
    0.04243298  0.05481576 -0.09572133]
  [ 0.33036014  0.22069807  0.40932164  0.5068694   0.25304586
    0.04534958  0.16947377  0.26356062]]

 [[ 0.3731386   0.02272611  0.24815443  0.16349553  0.01191301
   -0.10416587 -0.04668216  0.12466244]
  [ 0.2099585  -0.03679341  0.25529474 -0.05445585 -0.03499544
   -0.18863088 -0.09979748  0.17244026]
  [ 0.24914475 -0.00697784  0.24806115  0.02183823 -0.01399187
   -0.16620368 -0.08711094  0.14123847]
  [ 0.6468613   0.18447271  0.26571944  0.3662801   0.20576611
    0.05903479  0.13657802  0.14004102]]]
====== double layer bi lstm cn1 shape: (4, 4, 8) ======
[[[ 0.7061355  -0.13162777  0.46092123  0.4033497   0.2930356
    0.76054144  0.18314546  0.70929015]
  [ 0.6634344  -0.19959924  0.4972992   0.27566153 -0.0359614
    0.5509572   0.08536483  0.5394347 ]
  [ 0.5526391   0.1161246   0.5316373   0.28497726  0.22511882
    0.67451394 -0.12430747  0.5528798 ]
  [ 1.0954192   0.29093137  0.8067771   0.8504353   0.7032547
    0.97427243  0.5589305   0.8662672 ]]

 [[ 0.00324558  0.6688721  -0.05317001  0.32999027 -0.07784042
   -0.05728557  0.58330244  0.8111321 ]
  [ 0.16052541  0.31375027  0.1059354   0.28533533 -0.26115924
   -0.20904504  0.5899867   0.56931025]
  [-0.11802054  0.26023     0.07224996  0.31177503 -0.19568688
   -0.12562011  0.6177163   0.6840635 ]
  [ 0.16791074  1.2188046   0.00349617  0.670789    0.2591958
    0.46886685  0.5807996   0.86447406]]

 [[ 0.16193499  0.06143508  1.1399425   0.13840833  0.69956493
    0.04888431  0.1235408  -0.0485969 ]
  [-0.28950468 -0.0866928   0.7886544  -0.17458248  0.6081316
    0.12001929  0.17698729 -0.27595744]
  [-0.13397661 -0.12149224  0.9074148  -0.06176313  0.6541451
    0.12807912  0.1181712  -0.17463374]
  [ 0.8489872   0.6016479   1.3853014   0.8196937   1.020999
    0.24127276  0.45320526  0.4759813 ]]

 [[ 0.6076499   0.03351691  0.812855    0.27901018  0.02922555
   -0.26106828 -0.12472634  0.24901994]
  [ 0.3340806  -0.05591209  0.7391405  -0.08961776 -0.09178029
   -0.47115833 -0.27806604  0.35383248]
  [ 0.3964765  -0.01050393  0.7366462   0.03638346 -0.03574796
   -0.41335842 -0.23882627  0.28892466]
  [ 1.0575086   0.23200202  0.8150203   0.7750988   0.42505968
    0.24064866  0.46888143  0.26767123]]]
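
Item 4 of the analysis (time steps beyond the valid length are all zeros in output_1) can also be checked programmatically instead of by eye. The following is a minimal sketch; `padded_steps_are_zero` is a hypothetical helper, not a MindSpore API, and the toy data is made up for illustration rather than taken from the printed output.

```python
def padded_steps_are_zero(output, seq_length):
    """Return True if every step at or beyond a sample's valid length is all zeros.

    Hypothetical helper; expects a batch_first result indexed as
    output[batch][step][feature] (nested lists or a NumPy array).
    """
    for sample, valid_len in zip(output, seq_length):
        for step in sample[valid_len:]:
            if any(value != 0.0 for value in step):
                return False
    return True


# Toy data (not from the printed output): 2 samples, 3 steps, 2 features.
toy_output = [
    [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0]],  # valid length 2
    [[0.5, 0.6], [0.0, 0.0], [0.0, 0.0]],  # valid length 1
]
print(padded_steps_are_zero(toy_output, [2, 1]))  # True
print(padded_steps_are_zero(toy_output, [1, 1]))  # False: step 1 of sample 0 is non-zero
```

With a real run, the check could be called as `padded_steps_are_zero(output_1.asnumpy(), [3, 8, 5, 1])`. Note that zeroing the padded steps of output_0 would not reproduce output_1: in the printed results only the second sample (valid length 8) agrees between the two calls up to print precision, while the shorter samples differ even within their valid steps. This is consistent with the backward direction starting from the last valid step rather than from step 8.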

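Summary

This article briefly introduced the basic principles of LSTM and then, following the MindSpore documentation, used worked examples to explain the parameter settings and the inputs and outputs in detail, so that readers can better understand the LSTM operator in MindSpore.

References

  • Basic principles of LSTM
  • simplified-deeplearning/LSTM
  • LSTM API

This is an original article. Copyright belongs to the author; reproduction without authorization is prohibited.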
