| @@ -726,7 +726,7 @@ def get_bprop_basic_lstm_cell(self): | |||||
| def bprop(x, h, c, w, b, out, dout): | def bprop(x, h, c, w, b, out, dout): | ||||
| _, _, it, jt, ft, ot, tanhct = out | _, _, it, jt, ft, ot, tanhct = out | ||||
| dct, dht, _, _, _, _, _ = dout | dct, dht, _, _, _, _, _ = dout | ||||
| dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct) | |||||
| dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct) | |||||
| dxt, dht = basic_lstm_cell_input_grad(dgate, w) | dxt, dht = basic_lstm_cell_input_grad(dgate, w) | ||||
| dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) | dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) | ||||
| return dxt, dht, dct_1, dw, db | return dxt, dht, dct_1, dw, db | ||||
| @@ -37,10 +37,10 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \ | |||||
| .output(1, "dct_1", False, "required", "all") \ | .output(1, "dct_1", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | ||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | ||||
| DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| DataType.F16_FracNZ, DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | ||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | ||||
| DataType.F32_FracNZ, DataType.F16_FracNZ) \ | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| .get_op_info() | .get_op_info() | ||||