Browse Source

LSTM network adapt to cpu target.

tags/v0.5.0-beta
caojian05 5 years ago
parent
commit
a3cee90b25
1 changed files with 19 additions and 1 deletions
  1. +19
    -1
      mindspore/model_zoo/lstm.py

+ 19
- 1
mindspore/model_zoo/lstm.py View File

@@ -17,7 +17,7 @@ import math

import numpy as np

from mindspore import Parameter, Tensor, nn
from mindspore import Parameter, Tensor, nn, context, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P

@@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
if bidirectional:
num_directions = 2

if context.get_context("device_target") == "CPU":
h_list = []
c_list = []
for i in range(num_layers):
hi = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='h' + str(i))
h_list.append(hi)
ci = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='c' + str(i))
c_list.append(ci)
h = ParameterTuple(tuple(h_list))
c = ParameterTuple(tuple(c_list))
return h, c

h = Tensor(
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
c = Tensor(


Loading…
Cancel
Save