
!12933 Fix lstm output data type of Ascend.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago
commit 63ddc78145
2 changed files with 10 additions and 4 deletions
  1. mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc (+1, -1)
  2. mindspore/nn/layer/lstm.py (+9, -3)

mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc (+1, -1)

@@ -60,7 +60,7 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
     }
   }
   new_shape.erase(new_shape.begin() + axis + 1);
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {output_shape},
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {new_shape},
                                       new_pack.get());
   return new_pack;
 }


mindspore/nn/layer/lstm.py (+9, -3)

@@ -253,11 +253,14 @@ class LSTM(Cell):
             x = self.transpose(x, (1, 0, 2))
         h, c = hx
         if self.is_ascend:
+            x_dtype = F.dtype(x)
+            h_dtype = F.dtype(h)
+            c_dtype = F.dtype(c)
             _check_input_3d(F.shape(h), "h of hx", self.cls_name)
             _check_input_3d(F.shape(c), "c of hx", self.cls_name)
-            _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
-            _check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
-            _check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
+            _check_input_dtype(x_dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
+            _check_input_dtype(h_dtype, "h", [mstype.float32, mstype.float16], self.cls_name)
+            _check_input_dtype(c_dtype, "c", [mstype.float32, mstype.float16], self.cls_name)
             x = self.cast(x, mstype.float16)
             h = self.cast(h, mstype.float16)
             c = self.cast(c, mstype.float16)
@@ -265,6 +268,9 @@ class LSTM(Cell):
                 x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
             else:
                 x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
+            x = self.cast(x, x_dtype)
+            h = self.cast(h, h_dtype)
+            c = self.cast(c, c_dtype)
         else:
             x, h, c, _, _ = self.lstm(x, h, c, self.weight)
         if self.batch_first:

