Browse Source

fix lstm output.

tags/v1.2.0-rc1
liuxiao93 4 years ago
parent
commit
ba3bde6f85
2 changed files with 10 additions and 4 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc
  2. +9
    -3
      mindspore/nn/layer/lstm.py

+ 1
- 1
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc View File

@@ -60,7 +60,7 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
}
}
new_shape.erase(new_shape.begin() + axis + 1);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {output_shape},
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {new_shape},
new_pack.get());
return new_pack;
}


+ 9
- 3
mindspore/nn/layer/lstm.py View File

@@ -253,11 +253,14 @@ class LSTM(Cell):
x = self.transpose(x, (1, 0, 2))
h, c = hx
if self.is_ascend:
x_dtype = F.dtype(x)
h_dtype = F.dtype(h)
c_dtype = F.dtype(c)
_check_input_3d(F.shape(h), "h of hx", self.cls_name)
_check_input_3d(F.shape(c), "c of hx", self.cls_name)
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(x_dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(h_dtype, "h", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(c_dtype, "c", [mstype.float32, mstype.float16], self.cls_name)
x = self.cast(x, mstype.float16)
h = self.cast(h, mstype.float16)
c = self.cast(c, mstype.float16)
@@ -265,6 +268,9 @@ class LSTM(Cell):
x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
else:
x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
x = self.cast(x, x_dtype)
h = self.cast(h, h_dtype)
c = self.cast(c, c_dtype)
else:
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
if self.batch_first:


Loading…
Cancel
Save