Browse Source

LSTM network adapt to cpu target.

tags/v0.5.0-beta
caojian05 5 years ago
parent
commit
a3cee90b25
1 changed files with 19 additions and 1 deletions
  1. +19
    -1
      mindspore/model_zoo/lstm.py

+ 19
- 1
mindspore/model_zoo/lstm.py View File

@@ -17,7 +17,7 @@ import math


import numpy as np import numpy as np


from mindspore import Parameter, Tensor, nn
from mindspore import Parameter, Tensor, nn, context, ParameterTuple
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.ops import operations as P from mindspore.ops import operations as P


@@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
if bidirectional: if bidirectional:
num_directions = 2 num_directions = 2


if context.get_context("device_target") == "CPU":
h_list = []
c_list = []
for i in range(num_layers):
hi = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='h' + str(i))
h_list.append(hi)
ci = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='c' + str(i))
c_list.append(ci)
h = ParameterTuple(tuple(h_list))
c = ParameterTuple(tuple(c_list))
return h, c

h = Tensor( h = Tensor(
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
c = Tensor( c = Tensor(


Loading…
Cancel
Save