From dd593c674a8fe247c337249a0f65b31b54497a65 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Mon, 15 Jun 2020 10:39:43 +0800 Subject: [PATCH] fix basic lstm cell bp error --- mindspore/ops/_grad/grad_nn_ops.py | 2 +- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index e998afb269..4c4acb802c 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -726,7 +726,7 @@ def get_bprop_basic_lstm_cell(self): def bprop(x, h, c, w, b, out, dout): _, _, it, jt, ft, ot, tanhct = out dct, dht, _, _, _, _, _ = dout - dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct) + dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct) dxt, dht = basic_lstm_cell_input_grad(dgate, w) dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) return dxt, dht, dct_1, dw, db diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py index 099756ad35..440b1ce2c7 100644 --- a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py @@ -37,10 +37,10 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \ .output(1, "dct_1", False, "required", "all") \ .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, - DataType.F16_FracNZ, DataType.F16_FracNZ) \ + DataType.F16_FracNZ, DataType.F32_FracNZ) \ .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, - DataType.F32_FracNZ, DataType.F16_FracNZ) \ + DataType.F16_FracNZ, DataType.F16_FracNZ) \ .get_op_info()