Browse Source

fix basic lstm cell bp error

tags/v0.6.0-beta
zhaozhenlong 5 years ago
parent
commit
dd593c674a
2 changed files with 3 additions and 3 deletions
  1. +1
    -1
      mindspore/ops/_grad/grad_nn_ops.py
  2. +2
    -2
      mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py

+ 1
- 1
mindspore/ops/_grad/grad_nn_ops.py View File

@@ -726,7 +726,7 @@ def get_bprop_basic_lstm_cell(self):
def bprop(x, h, c, w, b, out, dout): def bprop(x, h, c, w, b, out, dout):
_, _, it, jt, ft, ot, tanhct = out _, _, it, jt, ft, ot, tanhct = out
dct, dht, _, _, _, _, _ = dout dct, dht, _, _, _, _, _ = dout
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct)
dxt, dht = basic_lstm_cell_input_grad(dgate, w) dxt, dht = basic_lstm_cell_input_grad(dgate, w)
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
return dxt, dht, dct_1, dw, db return dxt, dht, dct_1, dw, db


+ 2
- 2
mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py View File

@@ -37,10 +37,10 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \
.output(1, "dct_1", False, "required", "all") \ .output(1, "dct_1", False, "required", "all") \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ) \
DataType.F16_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F16_FracNZ) \
DataType.F16_FracNZ, DataType.F16_FracNZ) \
.get_op_info() .get_op_info()






Loading…
Cancel
Save