Browse Source

!1606 LSTM network adapt to cpu target.

Merge pull request !1606 from caojian05/ms_master_dev
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
c0d38e40a4
1 changed files with 19 additions and 1 deletions
  1. +19
    -1
      mindspore/model_zoo/lstm.py

+ 19
- 1
mindspore/model_zoo/lstm.py View File

@@ -17,7 +17,7 @@ import math


import numpy as np import numpy as np


from mindspore import Parameter, Tensor, nn
from mindspore import Parameter, Tensor, nn, context, ParameterTuple
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.ops import operations as P from mindspore.ops import operations as P


@@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
if bidirectional: if bidirectional:
num_directions = 2 num_directions = 2


if context.get_context("device_target") == "CPU":
h_list = []
c_list = []
for i in range(num_layers):
hi = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='h' + str(i))
h_list.append(hi)
ci = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='c' + str(i))
c_list.append(ci)
h = ParameterTuple(tuple(h_list))
c = ParameterTuple(tuple(c_list))
return h, c

h = Tensor( h = Tensor(
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
c = Tensor( c = Tensor(


Loading…
Cancel
Save