|
|
|
@@ -726,7 +726,7 @@ def get_bprop_basic_lstm_cell(self): |
|
|
|
def bprop(x, h, c, w, b, out, dout): |
|
|
|
_, _, it, jt, ft, ot, tanhct = out |
|
|
|
dct, dht, _, _, _, _, _ = dout |
|
|
|
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct) |
|
|
|
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct) |
|
|
|
dxt, dht = basic_lstm_cell_input_grad(dgate, w) |
|
|
|
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) |
|
|
|
return dxt, dht, dct_1, dw, db |
|
|
|
|