Merge pull request !31665 from Haim/export_haimr1.7
| @@ -36,6 +36,37 @@ void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, | |||
| } | |||
| } | |||
/*
 * Packs `count` consecutive source matrices for a single LSTM direction.
 *
 * Source matrix i is read from src + i * col * deep. It is written to the
 * destination slot `order[i]` (or slot i when `order` is NULL); slots are
 * col_align * deep floats apart. The layout conversion itself is delegated to
 * the platform-selected RowMajor2Col{16,4,8}Major routine.
 */
static void PackLstmWeightOneDirection(float *dst, const float *src, int count, int deep, int col, int col_align,
                                       const int *order) {
  for (int i = 0; i < count; i++) {
    const int slot = (order == NULL) ? i : order[i];
    const float *src_batch = src + i * col * deep;
    float *dst_batch = dst + slot * col_align * deep;
#ifdef ENABLE_AVX
    RowMajor2Col16Major(src_batch, dst_batch, col, deep);
#elif defined(ENABLE_ARM32)
    RowMajor2Col4Major(src_batch, dst_batch, col, deep);
#else
    RowMajor2Col8Major(src_batch, dst_batch, col, deep);
#endif
  }
}

/*
 * Packs LSTM weight matrices whose forward and backward halves are separated
 * by an explicit stride in the source buffer.
 *
 * dst              destination buffer; each packed matrix occupies col_align * deep floats.
 * src              source buffer; each matrix occupies col * deep floats.
 * batch            total number of matrices (both directions when bidirectional).
 * deep, col        dimensions of one source matrix, forwarded to the packing routine.
 * col_align        aligned column count used for the destination slot size.
 * is_bidirectional when true, batch / 2 matrices are packed per direction.
 * stride           offset in floats from the forward source block to the backward one.
 * order            optional permutation of destination slots within a direction; NULL keeps source order.
 *
 * The backward direction is written immediately after the forward block in
 * dst, i.e. at dst + (batch / 2) * col_align * deep.
 */
void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align,
                              bool is_bidirectional, int stride, const int *order) {
  const int unidirectional_batch = is_bidirectional ? batch / 2 : batch;

  /* forward direction */
  PackLstmWeightOneDirection(dst, src, unidirectional_batch, deep, col, col_align, order);

  if (is_bidirectional) {
    /* backward direction: source jumps by `stride`, destination follows the forward block */
    const float *backward_src = src + stride;
    float *backward_dst = dst + unidirectional_batch * col_align * deep;
    PackLstmWeightOneDirection(backward_dst, backward_src, unidirectional_batch, deep, col, col_align, order);
  }
}
| void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, | |||
| const int *order) { | |||
| int unidirectional_batch = is_bidirectional ? batch / 2 : batch; | |||
| @@ -55,6 +86,25 @@ void PackLstmBias(float *dst, const float *src, int batch, int col, int col_alig | |||
| } | |||
| } | |||
/*
 * Copies LSTM bias vectors into an aligned destination, with the forward and
 * backward halves of the source separated by an explicit stride.
 *
 * dst              destination buffer; bias rows are col_align floats apart.
 * src              source buffer; bias rows are col floats apart.
 * batch            total number of bias rows (both directions when bidirectional).
 * col              number of floats copied per row.
 * col_align        destination row pitch in floats; the padding beyond col is left untouched.
 * is_bidirectional when true, batch / 2 rows are copied per direction.
 * b_stride         offset in floats from the forward source rows to the backward ones.
 * order            optional permutation of destination rows within a direction; NULL keeps source order.
 *
 * Does nothing when dst or src is NULL or when col is not positive, so a bad
 * column count can never turn into an oversized memcpy length.
 */
void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
                            int b_stride, const int *order) {
  if (dst == NULL || src == NULL || col <= 0) {
    return;
  }
  const int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
  const int direction_count = is_bidirectional ? 2 : 1;
  const size_t row_bytes = (size_t)col * sizeof(float);

  for (int dir = 0; dir < direction_count; dir++) {
    /* dir 0 is the forward pass; dir 1 starts b_stride floats later in src and
     * right after the forward rows in dst. */
    const float *dir_src = src + dir * b_stride;
    float *dir_dst = dst + dir * unidirectional_batch * col_align;
    for (int i = 0; i < unidirectional_batch; i++) {
      const int slot = (order == NULL) ? i : order[i];
      memcpy(dir_dst + slot * col_align, dir_src + i * col, row_bytes);
    }
  }
}
| void PackLstmInput(const float *src, float *dst, int row, int deep) { | |||
| #ifdef ENABLE_AVX | |||
| RowMajor2Col6Major(src, dst, row, deep); | |||
| @@ -23,9 +23,15 @@ extern "C" { | |||
| #endif | |||
| void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int *order); | |||
| void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, | |||
| bool is_bidirectional, int stride, const int *order); | |||
| void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, | |||
| const int *order); | |||
| void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, | |||
| int b_stride, const int *order); | |||
| void PackLstmInput(const float *src, float *dst, int row, int deep); | |||
| void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align, | |||
| @@ -84,7 +84,7 @@ int GetGemmMatMullWorkspace(int batch, int input_size, int hidden_size) { | |||
| int workspace_size, temp; | |||
| // if the corresponding GemmMatmul uses beta>0, MatSizeTotal must have col as its last parameter. | |||
| workspace_size = MatSizeTotal(batch, input_size, hidden_size, input_size); | |||
| temp = MatSizeTotal(hidden_size, batch, hidden_size, batch); | |||
| temp = MatSizeTotal(batch, hidden_size, hidden_size, hidden_size); | |||
| workspace_size = (temp > workspace_size) ? temp : workspace_size; | |||
| temp = MatSizeTotal(hidden_size, input_size, batch, input_size); | |||
| workspace_size = (temp > workspace_size) ? temp : workspace_size; | |||
| @@ -94,103 +94,105 @@ int GetGemmMatMullWorkspace(int batch, int input_size, int hidden_size) { | |||
| } | |||
| int GetRunWorkspaceSize(const LstmGradParameter *lstm_param) { | |||
| int workspace_size = no_of_temp_matrices_sized_output_step * lstm_param->output_step_; | |||
| int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_; | |||
| int workspace_size = no_of_temp_matrices_sized_output_step * time_stamp_len; | |||
| workspace_size += GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_); | |||
| return workspace_size; | |||
| } | |||
| size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param) { | |||
| return no_of_temp_matrices_sized_output_step * lstm_param->output_step_; | |||
| int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_; | |||
| return no_of_temp_matrices_sized_output_step * time_stamp_len; | |||
| } | |||
| void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate, | |||
| float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX, | |||
| float *weights, float *workspace, const LstmGradParameter *lstm_param) { | |||
| float *w, float *v, float *workspace, const LstmGradParameter *lstm_param) { | |||
| float *scratchPad = workspace; | |||
| float *temp0 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| float *temp1 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| float *temp2 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| float *temp3 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| float *temp4 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| int seq_len = lstm_param->batch_ * lstm_param->hidden_size_; | |||
| float *temp0 = AllocteFromScrachPad(&scratchPad, seq_len); | |||
| float *temp1 = AllocteFromScrachPad(&scratchPad, seq_len); | |||
| float *temp2 = AllocteFromScrachPad(&scratchPad, seq_len); | |||
| float *temp3 = AllocteFromScrachPad(&scratchPad, seq_len); | |||
| float *temp4 = AllocteFromScrachPad(&scratchPad, seq_len); | |||
| // Accumulate gradients into dH | |||
| ElementAdd(dH, dY, dH, lstm_param->output_step_); | |||
| ElementAdd(dH, dY, dH, seq_len); | |||
| ElementMul(dH, output_gate, temp1, lstm_param->output_step_); | |||
| Tanh(cell_state, lstm_param->output_step_, temp0); | |||
| ElementMul(temp0, temp0, temp2, lstm_param->output_step_); | |||
| ElementMul(temp1, temp2, temp4, lstm_param->output_step_); | |||
| ElementSub(temp1, temp4, temp1, lstm_param->output_step_); | |||
| ElementAdd(dC, temp1, dC, lstm_param->output_step_); | |||
| ElementMul(dH, output_gate, temp1, seq_len); | |||
| Tanh(cell_state, seq_len, temp0); | |||
| ElementMul(temp0, temp0, temp2, seq_len); | |||
| ElementMul(temp1, temp2, temp4, seq_len); | |||
| ElementSub(temp1, temp4, temp1, seq_len); | |||
| ElementAdd(dC, temp1, dC, seq_len); | |||
| // calculate dI, dO, dF and dG | |||
| float *dI = temp1; // dI = dC_{t} * G | |||
| ElementMul(dC, cell_gate, dI, lstm_param->output_step_); | |||
| ElementMul(dC, cell_gate, dI, seq_len); | |||
| float *dO = temp2; // dO = dH * Tanh(C_{t}) | |||
| ElementMul(dH, temp0, dO, lstm_param->output_step_); | |||
| ElementMul(dH, temp0, dO, seq_len); | |||
| float *dF = temp3; // dF = dC_{t} * C_{t-1} | |||
| ElementMul(dC, prev_cell_state, dF, lstm_param->output_step_); | |||
| ElementMul(dC, prev_cell_state, dF, seq_len); | |||
| float *dG = temp4; // dG = dC_{t} * I | |||
| ElementMul(dC, input_gate, dG, lstm_param->output_step_); | |||
| ElementMul(dC, input_gate, dG, seq_len); | |||
| // dAi = dI * I * (1 - I) | |||
| float *dAi = temp1; | |||
| *dA = dAi; | |||
| ElementMul(dI, input_gate, dAi, lstm_param->output_step_); | |||
| ElementMul(dAi, input_gate, temp0, lstm_param->output_step_); | |||
| ElementSub(dAi, temp0, dAi, lstm_param->output_step_); | |||
| ElementMul(dI, input_gate, dAi, seq_len); | |||
| ElementMul(dAi, input_gate, temp0, seq_len); | |||
| ElementSub(dAi, temp0, dAi, seq_len); | |||
| // dAo = dO * O * (1 - O) | |||
| float *dAo = temp2; | |||
| ElementMul(dO, output_gate, dAo, lstm_param->output_step_); | |||
| ElementMul(dAo, output_gate, temp0, lstm_param->output_step_); | |||
| ElementSub(dAo, temp0, dAo, lstm_param->output_step_); | |||
| ElementMul(dO, output_gate, dAo, seq_len); | |||
| ElementMul(dAo, output_gate, temp0, seq_len); | |||
| ElementSub(dAo, temp0, dAo, seq_len); | |||
| // dAf = dF * F * (1 - F) | |||
| float *dAf = temp3; | |||
| ElementMul(dF, forget_gate, dAf, lstm_param->output_step_); | |||
| ElementMul(dAf, forget_gate, temp0, lstm_param->output_step_); | |||
| ElementSub(dAf, temp0, dAf, lstm_param->output_step_); | |||
| ElementMul(dF, forget_gate, dAf, seq_len); | |||
| ElementMul(dAf, forget_gate, temp0, seq_len); | |||
| ElementSub(dAf, temp0, dAf, seq_len); | |||
| float *dAg = temp4; | |||
| ElementMul(cell_gate, cell_gate, temp0, lstm_param->output_step_); | |||
| ElementMul(dG, temp0, temp0, lstm_param->output_step_); | |||
| ElementSub(dG, temp0, dAg, lstm_param->output_step_); | |||
| ElementMul(cell_gate, cell_gate, temp0, seq_len); | |||
| ElementMul(dG, temp0, temp0, seq_len); | |||
| ElementSub(dG, temp0, dAg, seq_len); | |||
| // calculate dX | |||
| size_t dX_size = lstm_param->batch_ * lstm_param->input_size_ * sizeof(float); | |||
| memset(dX, 0, dX_size); | |||
| float *mat_workspace = AllocteFromScrachPad( | |||
| &scratchPad, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_)); | |||
| float *weights_loop = weights; | |||
| float *weights_loop = w; | |||
| float *dA_loop = dAi; // dAi, dAo, dAf, dAg | |||
| for (int idx = 0; idx < num_of_gates; idx++) { | |||
| GemmMatmul(0, 0, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, 1.0, dA_loop, | |||
| lstm_param->hidden_size_, weights_loop, lstm_param->input_size_, 1.0, dX, lstm_param->input_size_, | |||
| mat_workspace); | |||
| weights_loop += lstm_param->hidden_size_ * lstm_param->input_size_; | |||
| dA_loop += lstm_param->output_step_; | |||
| dA_loop += seq_len; | |||
| } | |||
| // calculate dH next | |||
| size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float); | |||
| memset(dH, 0, dH_size); | |||
| dA_loop = dAi; | |||
| weights_loop = v; | |||
| for (int idx = 0; idx < num_of_gates; idx++) { | |||
| GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop, | |||
| lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_, | |||
| mat_workspace); | |||
| weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_; | |||
| dA_loop += lstm_param->output_step_; | |||
| dA_loop += seq_len; | |||
| } | |||
| // calculate dC next | |||
| ElementMul(dC, forget_gate, dC, lstm_param->output_step_); | |||
| ElementMul(dC, forget_gate, dC, seq_len); | |||
| } | |||
| void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *workspace, | |||
| const LstmGradParameter *lstm_param) { | |||
| void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB, | |||
| float *workspace, const LstmGradParameter *lstm_param) { | |||
| // Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg | |||
| int seq_len = lstm_param->batch_ * lstm_param->hidden_size_; | |||
| float *mat_workspace = AllocteFromScrachPad( | |||
| &workspace, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_)); | |||
| float *dA_loop = dA; // dAi, dAo, dAf, dAg | |||
| @@ -198,10 +200,10 @@ void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, f | |||
| int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_; | |||
| int dB_size = 0; | |||
| float *dW_loop = dW; | |||
| float *dV_loop = dW + (num_of_gates * dW_size); | |||
| float *dV_loop = dV; | |||
| float *dB_loop = 0; | |||
| if (lstm_param->has_bias_) { | |||
| dB_loop = dW + (num_of_gates * (dW_size + dV_size)); | |||
| dB_loop = dB; | |||
| dB_size = lstm_param->hidden_size_; | |||
| } | |||
| @@ -218,120 +220,7 @@ void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, f | |||
| if (dB_loop != 0) { | |||
| sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true); | |||
| } | |||
| dA_loop += lstm_param->output_step_; | |||
| dW_loop += dW_size; | |||
| dV_loop += dV_size; | |||
| dB_loop += dB_size; | |||
| } | |||
| } | |||
| void LstmGradDoStep(const float *output_gate, float *cell_state, float *cell_state_minus1, float *cell_gate, | |||
| float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float *dX, float *weights, | |||
| float *dW, float *hidden_state, float *input_t, float *workspace, | |||
| const LstmGradParameter *lstm_param) { | |||
| float *workspace_i = workspace; | |||
| float buffer[1024]; | |||
| float *scratchPad = buffer; | |||
| float *tanh_c = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| float *temp = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| float *temp2 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| float *tanh_c_sqr = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| // Accumulate gradients into dH | |||
| ElementAdd(dH, dY, dH, lstm_param->output_step_); | |||
| ElementMul(dH, output_gate, temp2, lstm_param->output_step_); | |||
| Tanh(cell_state, lstm_param->output_step_, tanh_c); | |||
| ElementMul(tanh_c, tanh_c, tanh_c_sqr, lstm_param->output_step_); | |||
| ElementMul(temp2, tanh_c_sqr, temp, lstm_param->output_step_); | |||
| ElementSub(temp2, temp, temp2, lstm_param->output_step_); | |||
| ElementAdd(dC, temp2, dC, lstm_param->output_step_); | |||
| // calculate dI, dO, dF and dG | |||
| float *dI = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dI = dC_{t} * G | |||
| ElementMul(dC, cell_gate, dI, lstm_param->output_step_); | |||
| float *dO = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dO = dH * Tanh(C_{t}) | |||
| ElementMul(dH, tanh_c, dO, lstm_param->output_step_); | |||
| float *dF = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dF = dC_{t} * C_{t-1} | |||
| ElementMul(dC, cell_state_minus1, dF, lstm_param->output_step_); | |||
| float *dG = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dG = dC_{t} * I | |||
| ElementMul(dC, input_gate, dG, lstm_param->output_step_); | |||
| // dAi = dI * I * (1 - I) | |||
| float *dAi = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| ElementMul(dI, input_gate, dAi, lstm_param->output_step_); | |||
| ElementMul(dAi, input_gate, temp, lstm_param->output_step_); | |||
| ElementSub(dAi, temp, dAi, lstm_param->output_step_); | |||
| // dAo = dO * O * (1 - O) | |||
| float *dAo = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| ElementMul(dO, output_gate, dAo, lstm_param->output_step_); | |||
| ElementMul(dAo, output_gate, temp, lstm_param->output_step_); | |||
| ElementSub(dAo, temp, dAo, lstm_param->output_step_); | |||
| // dAf = dF * F * (1 - F) | |||
| float *dAf = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| ElementMul(dF, forget_gate, dAf, lstm_param->output_step_); | |||
| ElementMul(dAf, forget_gate, temp, lstm_param->output_step_); | |||
| ElementSub(dAf, temp, dAf, lstm_param->output_step_); | |||
| // dAg = dG * (1 - G^2) | |||
| float *dAg = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); | |||
| ElementMul(cell_gate, cell_gate, dAg, lstm_param->output_step_); | |||
| ElementMul(dG, dAg, dAg, lstm_param->output_step_); | |||
| ElementSub(dG, dAg, dAg, lstm_param->output_step_); | |||
| // calculate dX | |||
| float *mat_workspace = AllocteFromScrachPad( | |||
| &workspace_i, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_)); | |||
| float *weights_loop = weights; | |||
| float *dA_loop = dAi; // dAi, dAo, dAf, dAg | |||
| for (int idx = 0; idx < num_of_gates; idx++) { | |||
| GemmMatmul(0, 0, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, 1.0, dA_loop, | |||
| lstm_param->hidden_size_, weights_loop, lstm_param->input_size_, 1.0, dX, lstm_param->input_size_, | |||
| mat_workspace); | |||
| weights_loop += lstm_param->hidden_size_ * lstm_param->input_size_; | |||
| dA_loop += lstm_param->output_step_; | |||
| } | |||
| // calculate dH next | |||
| size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float); | |||
| memset(dH, 0, dH_size); | |||
| dA_loop = dAi; | |||
| for (int idx = 0; idx < num_of_gates; idx++) { | |||
| GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop, | |||
| lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_, | |||
| mat_workspace); | |||
| weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_; | |||
| dA_loop += lstm_param->output_step_; | |||
| } | |||
| // calculate dC next | |||
| ElementMul(dC, forget_gate, dC, lstm_param->output_step_); | |||
| // Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg | |||
| dA_loop = dAi; | |||
| int dW_size = lstm_param->input_size_ * lstm_param->hidden_size_; | |||
| int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_; | |||
| int dB_size = lstm_param->hidden_size_; | |||
| float *dW_loop = dW; | |||
| float *dV_loop = dW + (num_of_gates * dW_size); | |||
| float *dB_loop = dW + (num_of_gates * (dW_size + dV_size)); | |||
| for (int idx = 0; idx < num_of_gates; idx++) { | |||
| // Calc dW | |||
| GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->input_size_, lstm_param->batch_, 1.0, dA_loop, | |||
| lstm_param->hidden_size_, input_t, lstm_param->input_size_, 1.0, dW_loop, lstm_param->input_size_, | |||
| mat_workspace); | |||
| // Calc dV | |||
| if (hidden_state != 0) { | |||
| GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->batch_, 1.0, dA_loop, | |||
| lstm_param->hidden_size_, hidden_state, lstm_param->hidden_size_, 1.0, dV_loop, | |||
| lstm_param->hidden_size_, mat_workspace); | |||
| } | |||
| // Clac dB | |||
| sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true); | |||
| dA_loop += lstm_param->output_step_; | |||
| dA_loop += seq_len; | |||
| dW_loop += dW_size; | |||
| dV_loop += dV_size; | |||
| dB_loop += dB_size; | |||
| @@ -58,10 +58,10 @@ void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col, | |||
| void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate, | |||
| float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX, | |||
| float *weights, float *workspace, const LstmGradParameter *lstm_param); | |||
| float *w, float *v, float *workspace, const LstmGradParameter *lstm_param); | |||
| void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *workspace, | |||
| const LstmGradParameter *lstm_param); | |||
| void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB, | |||
| float *workspace, const LstmGradParameter *lstm_param); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -51,7 +51,8 @@ int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, T | |||
| if (has_bias) { | |||
| output_shape[0] += C2NUM * gate_size; | |||
| } | |||
| int dir_mul = (param->bidirectional_) ? C2NUM : C1NUM; | |||
| output_shape[0] *= dir_mul; | |||
| SetShapeArray(output, output_shape, C3NUM); | |||
| return NNACL_OK; | |||
| @@ -18,7 +18,7 @@ | |||
| #include "nnacl/infer/infer_register.h" | |||
| static const int num_of_gates = 4; | |||
| static const int no_of_recorde_values = 7; | |||
| static const int no_of_recorde_values = 6; | |||
| int CheckInputShapeValid(const TensorC *const *inputs, const LstmParameter *parameter) { | |||
| const TensorC *input = inputs[FIRST_INPUT]; | |||
| @@ -70,14 +70,14 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o | |||
| if (!InferFlag(inputs, inputs_size)) { | |||
| return NNACL_INFER_INVALID; | |||
| } | |||
| int dir_multiplier = param->bidirectional_ ? 2 : 1; | |||
| int out_shape[MAX_SHAPE_SIZE]; | |||
| size_t out_shape_size = 0; | |||
| int hidden_size = 1; | |||
| ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); | |||
| if (inputs_size == DIMENSION_4D) { // if input from MINDIR | |||
| hidden_size = weight_i->shape_[THIRD_INPUT]; | |||
| out_shape[THIRD_INPUT] = hidden_size; | |||
| out_shape[THIRD_INPUT] = hidden_size * dir_multiplier; | |||
| } else { | |||
| if (CheckInputShapeValid(inputs, param) != NNACL_OK) { | |||
| return NNACL_ERR; | |||
| @@ -99,7 +99,7 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o | |||
| SetShapeArray(output, out_shape, out_shape_size); | |||
| int state_shape[MAX_SHAPE_SIZE]; | |||
| size_t state_shape_size = 0; | |||
| int dir_multiplier = param->bidirectional_ ? 2 : 1; | |||
| ShapeSet(state_shape, &state_shape_size, input->shape_, input->shape_size_); | |||
| state_shape[FIRST_INPUT] = dir_multiplier; | |||
| state_shape[THIRD_INPUT] = hidden_size; | |||
| @@ -74,6 +74,7 @@ void LstmCPUKernel::FreeRunBuffer() { | |||
| if (output_need_packed_) { | |||
| ms_context_->allocator->Free(buffer_[avx_state_output_index]); | |||
| } | |||
| ms_context_->allocator->Free(buffer_[tmp_hidden_output_index]); | |||
| } | |||
| int LstmCPUKernel::InitInputWeightBias() { | |||
| @@ -91,10 +92,16 @@ int LstmCPUKernel::InitInputWeightBias() { | |||
| const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr; | |||
| auto weight_i = in_tensors_.at(i_index); | |||
| auto weight_i_data = reinterpret_cast<float *>(weight_i->data()); | |||
| CHECK_NULL_RETURN(weight_i_data); | |||
| PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, lstm_param_->hidden_size_, | |||
| lstm_param_->input_col_align_, weights_order); | |||
| CHECK_NULL_RETURN(weight_i_data); | |||
| int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); | |||
| int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_); | |||
| int b_size = (lstm_param_->hidden_size_); | |||
| bool has_bias = (weight_batch_ * (cw_size + hh_size) < weight_i->ElementsNum()) ? true : false; | |||
| int stride = (gpu_orig_state_) ? gate_num * (cw_size + hh_size) : gate_num * (cw_size); | |||
| PackLstmWeightWithStride(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, | |||
| lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, | |||
| stride, weights_order); | |||
| // input bias | |||
| input_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float))); | |||
| if (input_bias_ == nullptr) { | |||
| @@ -103,18 +110,19 @@ int LstmCPUKernel::InitInputWeightBias() { | |||
| } | |||
| memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float)); | |||
| float *bias_data = nullptr; | |||
| int offset = gate_num * lstm_param_->hidden_size_ * (lstm_param_->input_size_ + lstm_param_->hidden_size_); | |||
| if (weight_i->ElementsNum() > offset) { | |||
| bias_data = weight_i_data + offset; | |||
| } | |||
| int offset = weight_batch_ * (cw_size + hh_size); | |||
| float *bias_data = (has_bias) ? weight_i_data + offset : nullptr; | |||
| int dir_mul = lstm_param_->bidirectional_ ? 2 : 1; | |||
| int b_stride = (gpu_orig_state_) ? gate_num * (dir_mul * b_size) : gate_num * (b_size); | |||
| if (in_tensors_.size() > mindir_input_tensors) { | |||
| bias_data = reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data()); | |||
| } | |||
| if (bias_data != nullptr) { | |||
| PackLstmBias(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, | |||
| lstm_param_->bidirectional_, weights_order); | |||
| } else { | |||
| if (bias_data != nullptr) { | |||
| PackLstmBiasWithStride(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, | |||
| lstm_param_->input_col_align_, lstm_param_->bidirectional_, b_stride, weights_order); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -124,14 +132,23 @@ int LstmCPUKernel::InitStateWeightBias() { | |||
| // state -- row: batch; col: hidden_size | |||
| // weight -- row: hidden_size; col: hidden_size, need transpose | |||
| // result -- row: batch; col: hidden_size | |||
| int weight_i_size = gate_num * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| int weight_i_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| int h_index = (in_tensors_.size() == mindir_input_tensors) ? combined_weights_index : onnx_weight_h_index; | |||
| auto weight_h = in_tensors_.at(h_index); | |||
| auto weight_h_data = (reinterpret_cast<float *>(weight_h->data())); | |||
| int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); | |||
| int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_); | |||
| int b_size = (lstm_param_->hidden_size_); | |||
| int stride = (gpu_orig_state_) ? gate_num * (cw_size + hh_size) : gate_num * (hh_size); | |||
| if (in_tensors_.size() == mindir_input_tensors) { | |||
| weight_h_data += weight_i_size; | |||
| if (gpu_orig_state_) { | |||
| weight_h_data += gate_num * cw_size; | |||
| } else { | |||
| weight_h_data += weight_i_size; | |||
| } | |||
| } | |||
| CHECK_NULL_RETURN(weight_h_data); | |||
| if (!state_is_vec_) { | |||
| weight_h_ptr_ = reinterpret_cast<float *>( | |||
| @@ -141,8 +158,9 @@ int LstmCPUKernel::InitStateWeightBias() { | |||
| return RET_ERROR; | |||
| } | |||
| const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr; | |||
| PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, | |||
| lstm_param_->state_col_align_, weights_order); | |||
| PackLstmWeightWithStride(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_, | |||
| lstm_param_->hidden_size_, lstm_param_->state_col_align_, lstm_param_->bidirectional_, | |||
| stride, weights_order); | |||
| } else { | |||
| #ifdef ENABLE_AVX | |||
| weight_h_ptr_ = reinterpret_cast<float *>( | |||
| @@ -162,8 +180,8 @@ int LstmCPUKernel::InitStateWeightBias() { | |||
| } | |||
| // state bias | |||
| int weight_h_size = gate_num * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| int bias_size = gate_num * lstm_param_->hidden_size_; | |||
| int weight_h_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| int bias_size = weight_batch_ * lstm_param_->hidden_size_; | |||
| state_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float))); | |||
| if (state_bias_ == nullptr) { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc state_bias_ error."; | |||
| @@ -179,9 +197,13 @@ int LstmCPUKernel::InitStateWeightBias() { | |||
| lstm_param_->bidirectional_, nullptr); | |||
| } else if (weight_h->ElementsNum() - weight_i_size - weight_h_size - C2NUM * bias_size == 0) { | |||
| // mindir from device "GPU", second bias is also present, order IFOG | |||
| float *state_bias = weight_h_data + weight_h_size + bias_size; | |||
| PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, | |||
| lstm_param_->bidirectional_, weights_order_IFOG); | |||
| int dir_mul = lstm_param_->bidirectional_ ? 2 : 1; | |||
| int bias_offset = | |||
| (gpu_orig_state_) ? gate_num * ((dir_mul - 1) * cw_size + dir_mul * hh_size + b_size) : weight_h_size + bias_size; | |||
| float *state_bias = weight_h_data + bias_offset; | |||
| int b_stride = (gpu_orig_state_) ? gate_num * (b_size * C2NUM) : gate_num * b_size; | |||
| PackLstmBiasWithStride(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, | |||
| lstm_param_->state_col_align_, lstm_param_->bidirectional_, b_stride, weights_order_IFOG); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -202,11 +224,25 @@ int LstmCPUKernel::InitParam() { | |||
| } else { | |||
| lstm_param_->hidden_size_ = w_shape.at(SECOND_INPUT) / gate_num; | |||
| } | |||
| lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ | |||
| : lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| weight_batch_ = lstm_param_->bidirectional_ ? 2 * gate_num : gate_num; | |||
| state_is_vec_ = lstm_param_->batch_ == 1; | |||
| // determine FB origin | |||
| gpu_orig_state_ = false; | |||
| if (in_tensors_.size() == mindir_input_tensors) { | |||
| gpu_orig_state_ = gpu_orig_cfg_; | |||
| auto weight_t = in_tensors_.at(combined_weights_index); | |||
| int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); | |||
| int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_); | |||
| int b_size = (lstm_param_->hidden_size_); | |||
| bool has_bias = (weight_batch_ * (cw_size + hh_size) < weight_t->ElementsNum()) ? true : false; | |||
| // if bias exist we can determine the gpu_orig_state_ | |||
| if (has_bias) { | |||
| gpu_orig_state_ = | |||
| (weight_batch_ * (cw_size + hh_size + C2NUM * b_size) == weight_t->ElementsNum()) ? true : false; | |||
| } | |||
| } | |||
| #ifdef ENABLE_AVX | |||
| row_tile_ = C6NUM; | |||
| @@ -279,7 +315,7 @@ int LstmCPUKernel::MallocRunBuffer() { | |||
| buffer_[packed_input_index] = reinterpret_cast<float *>( | |||
| ms_context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float))); | |||
| if (buffer_[packed_input_index] == nullptr) { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error."; | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matrix error."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -338,6 +374,16 @@ int LstmCPUKernel::MallocRunBuffer() { | |||
| } | |||
| } | |||
| #endif | |||
| buffer_[tmp_hidden_output_index] = nullptr; | |||
| if (!(in_tensors_.size() > mindir_input_tensors)) { | |||
| auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); | |||
| buffer_[tmp_hidden_output_index] = reinterpret_cast<float *>(ms_context_->allocator->Malloc(buffer_size)); | |||
| if (buffer_[tmp_hidden_output_index] == nullptr) { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for hidden error."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -367,7 +413,7 @@ int LstmInputMulWeightRun(void *cdata, int task_id, float, float) { | |||
| int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, | |||
| const float *input_bias, const float *state_bias, float *hidden_state, | |||
| float *cell_state, bool is_backward) { | |||
| float *cell_state, float *intermediate_states, bool is_backward) { | |||
| float *gate = buffer_[input_gate_index]; | |||
| for (int i = 0; i < gate_num; i++) { | |||
| weight_loop_ = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i; | |||
| @@ -383,30 +429,46 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons | |||
| float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * 2; | |||
| float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * 3; | |||
| float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| float *tmp = buffer_[tmp_hidden_output_index]; | |||
| int dir_mult = lstm_param_->bidirectional_ ? 2 : 1; | |||
| for (int t = 0; t < lstm_param_->seq_len_; t++) { | |||
| int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t; | |||
| float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; | |||
| float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; | |||
| float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; | |||
| float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; | |||
| float *output_ptr = output + real_t * lstm_param_->output_step_; | |||
| LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, | |||
| hidden_state, cell_state, buffer_, lstm_param_); | |||
| if (IsTrain() && IsTrainable()) { | |||
| // if ONNX | |||
| if (in_tensors_.size() > mindir_input_tensors) { | |||
| // Sequence, DirMul, Batch, Hidden | |||
| float *output_ptr = output + real_t * lstm_param_->output_step_; | |||
| LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, | |||
| hidden_state, cell_state, buffer_, lstm_param_); | |||
| } else { | |||
| // Sequence, Batch, DirMul, Hidden | |||
| LstmStepUnit(tmp, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, hidden_state, | |||
| cell_state, buffer_, lstm_param_); | |||
| int seq_offset = real_t * lstm_param_->batch_ * dir_mult * lstm_param_->hidden_size_; | |||
| for (int b = 0; b < lstm_param_->batch_; b++) { | |||
| int batch_offset = b * dir_mult * lstm_param_->hidden_size_; | |||
| float *output_ptr = output + seq_offset + batch_offset; | |||
| memcpy(output_ptr, tmp + b * lstm_param_->hidden_size_, lstm_param_->hidden_size_ * sizeof(float)); | |||
| } | |||
| } | |||
| if (intermediate_states) { | |||
| RecordStates(hidden_state, cell_state, input_gate_t, output_gate_t, forget_gate_t, cell_gate_t, | |||
| is_backward ? real_t : t); | |||
| intermediate_states, real_t); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void LstmCPUKernel::RecordStates(float *hidden_state, float *cell_state, float *input_gate, float *output_gate, | |||
| float *forget_gate, float *cell_gate, int step) { | |||
| float *states = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data()); | |||
| float *forget_gate, float *cell_gate, float *intermediate_states, int step) { | |||
| float *states = intermediate_states; | |||
| auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| auto stride = step * state_size; | |||
| auto seq_stride = lstm_param_->seq_len_ * state_size; | |||
| auto stride = step * lstm_param_->output_step_; | |||
| auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; | |||
| memcpy(states + stride, hidden_state, state_size * sizeof(float)); | |||
| stride += seq_stride; | |||
| memcpy(states + stride, cell_state, state_size * sizeof(float)); | |||
| @@ -425,8 +487,12 @@ int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden | |||
| // buffer_[packed_input_index] : store packed input | |||
| PackLstmInput(input, buffer_[packed_input_index], lstm_param_->seq_len_ * lstm_param_->batch_, | |||
| lstm_param_->input_size_); | |||
| auto ret = | |||
| LstmUnidirectional(output, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, hidden_state, cell_state, false); | |||
| float *intermediate_states = nullptr; | |||
| if (IsTrain() && IsTrainable()) { | |||
| intermediate_states = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data()); | |||
| } | |||
| auto ret = LstmUnidirectional(output, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, hidden_state, | |||
| cell_state, intermediate_states, false); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Lstm unidirectional calculation error."; | |||
| return RET_ERROR; | |||
| @@ -441,11 +507,17 @@ int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden | |||
| const float *backward_input_bias = input_bias_ + gate_num * lstm_param_->input_col_align_; | |||
| const float *backward_state_bias = state_bias_ + gate_num * lstm_param_->state_col_align_; | |||
| float *backward_output = output + lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| if (in_tensors_.size() == mindir_input_tensors) { | |||
| backward_output = output + lstm_param_->hidden_size_; | |||
| } | |||
| float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| ret = LstmUnidirectional(backward_output, backward_weight_i, backward_weight_h, backward_input_bias, | |||
| backward_state_bias, backward_hidden_state, backward_cell_state, true); | |||
| if (intermediate_states) { | |||
| intermediate_states += lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| } | |||
| ret = | |||
| LstmUnidirectional(backward_output, backward_weight_i, backward_weight_h, backward_input_bias, | |||
| backward_state_bias, backward_hidden_state, backward_cell_state, intermediate_states, true); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Lstm bidirectional calculation error."; | |||
| return RET_ERROR; | |||
| @@ -47,10 +47,11 @@ class LstmCPUKernel : public InnerKernel { | |||
| int InitStateWeightBias(); | |||
| int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias, | |||
| const float *state_bias, float *hidden_state, float *cell_state, bool is_backward); | |||
| const float *state_bias, float *hidden_state, float *cell_state, float *intermediate_states, | |||
| bool is_backward); | |||
| int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state); | |||
| void RecordStates(float *hidden_state, float *cell_state, float *input_gate, float *output_gate, float *forget_gate, | |||
| float *cell_gate, int step); | |||
| float *cell_gate, float *intermediate_states, int step); | |||
| const float *weight_loop_; | |||
| const float *bias_loop_; | |||
| float *gate_loop_ = nullptr; | |||
| @@ -75,7 +76,7 @@ class LstmCPUKernel : public InnerKernel { | |||
| int hidden_state_input_index_ = onnx_hidden_state_index; | |||
| int cell_state_input_index_ = onnx_cell_state_index; | |||
| float *buffer_[7] = {nullptr}; | |||
| float *buffer_[8] = {nullptr}; | |||
| const int gate_num = 4; | |||
| const int packed_input_index = 0; | |||
| const int input_gate_index = 1; | |||
| @@ -84,6 +85,7 @@ class LstmCPUKernel : public InnerKernel { | |||
| const int cell_state_index = 4; | |||
| const int hidden_state_index = 5; | |||
| const int avx_state_output_index = 6; | |||
| const int tmp_hidden_output_index = 7; | |||
| static const int out_intermediate_states_index = 3; | |||
| const int weights_order_IFOG[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 5}; // IFGO order to IOFG order | |||
| @@ -94,6 +96,9 @@ class LstmCPUKernel : public InnerKernel { | |||
| int weight_batch_ = 0; | |||
| bool state_is_vec_ = false; | |||
| bool output_need_packed_ = false; | |||
| // control weight layout | |||
| bool gpu_orig_state_ = true; | |||
| bool gpu_orig_cfg_ = true; | |||
| LstmParameter *lstm_param_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -43,21 +43,11 @@ int LSTMGradDataCPUKernel::ReSize() { return InitParam(); } | |||
| int LSTMGradDataCPUKernel::Run() { | |||
| auto ret = MallocRunBuffer(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "LstmGradDataCPUKernel MallocRunBuffer error."; | |||
| MS_LOG(ERROR) << "LSTMGradDataCPUKernel MallocRunBuffer error."; | |||
| FreeRunBuffer(); | |||
| return RET_ERROR; | |||
| } | |||
| auto output = out_tensors_.at(0); | |||
| auto output_ptr = reinterpret_cast<float *>(output->data()); | |||
| CHECK_NULL_RETURN(output_ptr); | |||
| LstmBackpropUnidirectional(output_ptr, false); | |||
| FreeRunBuffer(); | |||
| return RET_OK; | |||
| } | |||
| int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backward) { | |||
| // get input tensors | |||
| auto dC_tensor = in_tensors_.at(dC_index); | |||
| MS_ASSERT(dC_tensor != nullptr); | |||
| @@ -71,8 +61,6 @@ int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(float *output, bool is_bac | |||
| MS_ASSERT(intermediate_tensor != nullptr); | |||
| auto cell_input_tensor = in_tensors_.at(cell_input_index); | |||
| MS_ASSERT(cell_input_tensor != nullptr); | |||
| // Get output tensors | |||
| auto dX_tensor = out_tensors_.at(dX_out_index); | |||
| MS_ASSERT(dX_tensor != nullptr); | |||
| auto dH_out_tensor = out_tensors_.at(dH_out_index); | |||
| @@ -80,58 +68,119 @@ int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(float *output, bool is_bac | |||
| auto dC_out_tensor = out_tensors_.at(dC_out_index); | |||
| MS_ASSERT(dC_out_tensor != nullptr); | |||
| auto cell_input_data = reinterpret_cast<float *>(cell_input_tensor->data()); | |||
| auto dh_out = reinterpret_cast<float *>(dH_out_tensor->data()); | |||
| auto dc_out = reinterpret_cast<float *>(dC_out_tensor->data()); | |||
| auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data()); | |||
| auto dC = reinterpret_cast<float *>(dC_tensor->data()); | |||
| auto dH = reinterpret_cast<float *>(dH_tensor->data()); | |||
| auto dY = reinterpret_cast<float *>(dy_tensor->data()); | |||
| auto dX = reinterpret_cast<float *>(dX_tensor->data()); | |||
| auto weights = reinterpret_cast<float *>(weights_tensor->data()); | |||
| auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| auto seq_stride = lstm_param_->seq_len_ * state_size; | |||
| float *cell_state = intermediate_data + seq_stride * 1; | |||
| float *input_gate = intermediate_data + seq_stride * 2; | |||
| float *output_gate = intermediate_data + seq_stride * 3; | |||
| float *forget_gate = intermediate_data + seq_stride * 4; | |||
| float *cell_gate = intermediate_data + seq_stride * 5; | |||
| // reorder weights only from IFGO to IOFG | |||
| ReorderLstmWeightGrad(weights_tmp_, weights); | |||
| memset(dH, 0, dH_tensor->Size()); | |||
| memset(dC, 0, dC_tensor->Size()); | |||
| // Get Tensors Data | |||
| int time_stamp_len = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| weights_ = reinterpret_cast<float *>(weights_tensor->data()); | |||
| ReorderLstmWeightGrad(weights_tmp_, weights_); | |||
| dC_ = reinterpret_cast<float *>(dC_tensor->data()); | |||
| dH_ = reinterpret_cast<float *>(dH_tensor->data()); | |||
| dX_ = reinterpret_cast<float *>(dX_tensor->data()); | |||
| memset(dH_, 0, dH_tensor->Size()); | |||
| memset(dC_, 0, dC_tensor->Size()); | |||
| memset(dX_, 0, dX_tensor->Size()); | |||
| int w_size = lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| int h_size = lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| float *orig_da = dA_tmp_; | |||
| if (lstm_param_->bidirectional_) { | |||
| // Adjust pointer to backward cell | |||
| cell_input_data_ = reinterpret_cast<float *>(cell_input_tensor->data()) + time_stamp_len; | |||
| intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data()) + time_stamp_len; | |||
| dC_ = reinterpret_cast<float *>(dC_tensor->data()) + time_stamp_len; | |||
| dH_ = reinterpret_cast<float *>(dH_tensor->data()) + time_stamp_len; | |||
| dY_ = reinterpret_cast<float *>(dy_tensor->data()) + lstm_param_->hidden_size_; | |||
| dA_tmp_ = orig_da + lstm_param_->seq_len_ * num_of_gates * time_stamp_len; | |||
| int w_offset = num_of_gates * (w_size + h_size); | |||
| int v_offset = weight_batch_ * w_size + num_of_gates * h_size; | |||
| float *w = weights_tmp_ + w_offset; | |||
| float *v = weights_tmp_ + v_offset; | |||
| LstmBackpropUnidirectional(true, w, v); | |||
| } | |||
| // adjust to forward cell | |||
| cell_input_data_ = reinterpret_cast<float *>(cell_input_tensor->data()); | |||
| intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data()); | |||
| dC_ = reinterpret_cast<float *>(dC_tensor->data()); | |||
| dH_ = reinterpret_cast<float *>(dH_tensor->data()); | |||
| dY_ = reinterpret_cast<float *>(dy_tensor->data()); | |||
| int w_offset = 0; | |||
| int v_offset = num_of_gates * w_size; | |||
| float *w = weights_tmp_ + w_offset; | |||
| float *v = weights_tmp_ + v_offset; | |||
| dA_tmp_ = orig_da; | |||
| LstmBackpropUnidirectional(false, w, v); | |||
| // setup output tensors | |||
| dh_out_ = reinterpret_cast<float *>(dH_out_tensor->data()); | |||
| dc_out_ = reinterpret_cast<float *>(dC_out_tensor->data()); | |||
| std::copy(&(dH_[0]), &(dH_[dH_tensor->ElementsNum()]), &(dh_out_[0])); | |||
| std::copy(&(dC_[0]), &(dC_[dC_tensor->ElementsNum()]), &(dc_out_[0])); | |||
| auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; | |||
| float *cell_state = intermediate_data_ + seq_stride * 1; | |||
| std::copy(&(dA_tmp_[0]), &(dA_tmp_[num_of_gates * seq_stride]), &(cell_state[0])); | |||
| FreeRunBuffer(); | |||
| return RET_OK; | |||
| } | |||
| int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(bool is_backward, float *w, float *v) { | |||
| auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; | |||
| int state_len = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| float *cell_state = intermediate_data_ + seq_stride * 1; | |||
| float *input_gate = intermediate_data_ + seq_stride * 2; | |||
| float *output_gate = intermediate_data_ + seq_stride * 3; | |||
| float *forget_gate = intermediate_data_ + seq_stride * 4; | |||
| float *cell_gate = intermediate_data_ + seq_stride * 5; | |||
| int dir_mult = lstm_param_->bidirectional_ ? 2 : 1; | |||
| int prev_time_stamp_offset = (is_backward) ? 1 : -1; | |||
| int first_time_stamp = (is_backward) ? lstm_param_->seq_len_ - 1 : 0; | |||
| for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) { | |||
| int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t; | |||
| auto stride = real_t * state_size; | |||
| auto stride = real_t * lstm_param_->output_step_; | |||
| float *curr_cell_state = cell_state + stride; | |||
| float *prev_cell_state = (real_t > 0) ? cell_state + (real_t - 1) * state_size : cell_input_data; | |||
| float *prev_cell_state = (real_t == first_time_stamp) | |||
| ? cell_input_data_ | |||
| : cell_state + (real_t + prev_time_stamp_offset) * lstm_param_->output_step_; | |||
| float *curr_input_gate = input_gate + stride; | |||
| float *curr_forget_gate = forget_gate + stride; | |||
| float *curr_cell_gate = cell_gate + stride; | |||
| float *curr_output_gate = output_gate + stride; | |||
| float *curr_dx = dX + real_t * lstm_param_->batch_ * lstm_param_->input_size_; | |||
| float *curr_dy = dY + real_t * state_size; | |||
| float *curr_dx = dX_ + real_t * lstm_param_->batch_ * lstm_param_->input_size_; | |||
| int seq_offset = real_t * lstm_param_->output_step_; | |||
| for (int b = 0; b < lstm_param_->batch_; b++) { | |||
| int batch_offset = b * dir_mult * lstm_param_->hidden_size_; | |||
| float *dy = dY_ + seq_offset + batch_offset; | |||
| memcpy(curr_dy_ + b * lstm_param_->hidden_size_, dy, lstm_param_->hidden_size_ * sizeof(float)); | |||
| } | |||
| float *dA = nullptr; | |||
| LstmGradDoInputStep(curr_output_gate, curr_cell_state, prev_cell_state, curr_cell_gate, curr_input_gate, | |||
| curr_forget_gate, curr_dy, dC, dH, &dA, curr_dx, weights_tmp_, workspace_, lstm_param_); | |||
| float *dA_t = dA_tmp_ + t * num_of_gates * lstm_param_->output_step_; | |||
| std::copy(&(dA[0]), &(dA[num_of_gates * lstm_param_->output_step_]), &dA_t[0]); // for w grad step | |||
| curr_forget_gate, curr_dy_, dC_, dH_, &dA, curr_dx, w, v, workspace_, lstm_param_); | |||
| float *dA_t = dA_tmp_ + real_t * num_of_gates * state_len; | |||
| std::copy(&(dA[0]), &(dA[num_of_gates * state_len]), &dA_t[0]); // for w grad step | |||
| } | |||
| std::copy(&(dH[0]), &(dH[state_size]), &(dh_out[0])); | |||
| std::copy(&(dC[0]), &(dC[state_size]), &(dc_out[0])); | |||
| std::copy(&(dA_tmp_[0]), &(dA_tmp_[num_of_gates * lstm_param_->output_step_ * lstm_param_->seq_len_]), | |||
| &(cell_state[0])); | |||
| return RET_OK; | |||
| } | |||
| void LSTMGradDataCPUKernel::ReorderLstmWeightGrad(float *dst, float *src) { | |||
| ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIFGO()); | |||
| src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIFGO()); | |||
| int uni_batch = lstm_param_->bidirectional_ ? weight_batch_ / 2 : weight_batch_; | |||
| ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIFGO()); | |||
| src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIFGO()); | |||
| src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| if (lstm_param_->bidirectional_) { | |||
| ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIFGO()); | |||
| src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIFGO()); | |||
| src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| } | |||
| } | |||
| int LSTMGradDataCPUKernel::DoGrad(int thread_id) { return RET_OK; } | |||
| @@ -143,11 +192,6 @@ int LSTMGradDataCPUKernel::InitParam() { | |||
| lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT); | |||
| lstm_param_->batch_ = in_shape.at(SECOND_INPUT); | |||
| auto dy = in_tensors_.at(dy_index); | |||
| MS_ASSERT(dy != nullptr); | |||
| std::vector<int> dy_shape = dy->shape(); | |||
| lstm_param_->hidden_size_ = dy_shape.at(THIRD_INPUT); | |||
| int dir_multiplier = lstm_param_->bidirectional_ ? 2 : 1; | |||
| lstm_param_->output_step_ = dir_multiplier * lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| weight_batch_ = dir_multiplier * num_of_gates; | |||
| @@ -184,7 +228,7 @@ int LSTMGradDataCPUKernel::InitParam() { | |||
| int LSTMGradDataCPUKernel::MallocRunBuffer() { | |||
| int workspace_size = GetRunWorkspaceSize(lstm_param_); | |||
| if ((workspace_size == 0) || (workspace_size > LSTMGRADDATA_MAX_WORKSPACE_SIZE)) { | |||
| if (workspace_size == 0) { | |||
| MS_LOG(ERROR) << "LstmGradDataCPUKernel malloc run workspace 0 error."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -194,7 +238,7 @@ int LSTMGradDataCPUKernel::MallocRunBuffer() { | |||
| return RET_ERROR; | |||
| } | |||
| auto dA_size = num_of_gates * lstm_param_->output_step_ * lstm_param_->seq_len_; | |||
| if ((dA_size == 0) || (dA_size > LSTMGRADDATA_MAX_WORKSPACE_SIZE)) { | |||
| if (dA_size == 0) { | |||
| MS_LOG(ERROR) << "LstmGradDataCPUKernel malloc run dA_tmp size error."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -210,6 +254,12 @@ int LSTMGradDataCPUKernel::MallocRunBuffer() { | |||
| MS_LOG(ERROR) << "LstmGradWeightCPUKernel malloc run weights_tmp_ alloc error."; | |||
| return RET_ERROR; | |||
| } | |||
| int curr_dy_size = lstm_param_->hidden_size_ * lstm_param_->batch_; | |||
| curr_dy_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(curr_dy_size * sizeof(float))); | |||
| if (curr_dy_ == nullptr) { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc run curr_dy_ alloc error."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -226,6 +276,10 @@ void LSTMGradDataCPUKernel::FreeRunBuffer() { | |||
| ms_context_->allocator->Free(weights_tmp_); | |||
| weights_tmp_ = nullptr; | |||
| } | |||
| if (curr_dy_ != nullptr) { | |||
| ms_context_->allocator->Free(curr_dy_); | |||
| curr_dy_ = nullptr; | |||
| } | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTMGradData, LiteKernelCreator<LSTMGradDataCPUKernel>) | |||
| @@ -23,8 +23,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr int LSTMGRADDATA_MAX_WORKSPACE_SIZE = 100000; | |||
| constexpr int LSTMGRADDATA_MAX_WEIGHTS_SIZE = 100000; | |||
| class LSTMGradDataCPUKernel : public InnerKernel { | |||
| public: | |||
| explicit LSTMGradDataCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| @@ -39,7 +37,7 @@ class LSTMGradDataCPUKernel : public InnerKernel { | |||
| int DoGrad(int thread_id); | |||
| private: | |||
| int LstmBackpropUnidirectional(float *output, bool is_backward); | |||
| int LstmBackpropUnidirectional(bool is_backward, float *w, float *v); | |||
| void ReorderLstmWeightGrad(float *dst, float *src); | |||
| int InitParam(); | |||
| @@ -71,6 +69,16 @@ class LSTMGradDataCPUKernel : public InnerKernel { | |||
| int input_thread_count_ = 0; | |||
| int input_thread_stride_ = 0; | |||
| float *curr_dy_ = nullptr; | |||
| float *weights_ = nullptr; | |||
| float *dC_ = nullptr; | |||
| float *dH_ = nullptr; | |||
| float *dX_ = nullptr; | |||
| float *dY_ = nullptr; | |||
| float *cell_input_data_ = nullptr; | |||
| float *intermediate_data_ = nullptr; | |||
| float *dh_out_ = nullptr; | |||
| float *dc_out_ = nullptr; | |||
| LstmGradParameter *lstm_param_ = nullptr; | |||
| }; | |||
| } // namespace kernel | |||
| @@ -47,12 +47,7 @@ int LSTMGradCPUKernel::Run() { | |||
| FreeRunBuffer(); | |||
| return RET_ERROR; | |||
| } | |||
| LstmBackpropUnidirectional(false); | |||
| FreeRunBuffer(); | |||
| return RET_OK; | |||
| } | |||
| int LSTMGradCPUKernel::LstmBackpropUnidirectional(bool is_backward) { | |||
| // get input tensors | |||
| auto dC_tensor = in_tensors_.at(dC_index); | |||
| MS_ASSERT(dC_tensor != nullptr); | |||
| @@ -81,58 +76,113 @@ int LSTMGradCPUKernel::LstmBackpropUnidirectional(bool is_backward) { | |||
| auto dC_out_tensor = out_tensors_.at(dC_out_index); | |||
| MS_ASSERT(dC_out_tensor != nullptr); | |||
| auto cell_input_data = reinterpret_cast<float *>(cell_input_tensor->data()); | |||
| auto hidden_input_data = reinterpret_cast<float *>(hidden_input_tensor->data()); | |||
| auto dh_out = reinterpret_cast<float *>(dH_out_tensor->data()); | |||
| auto dc_out = reinterpret_cast<float *>(dC_out_tensor->data()); | |||
| auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data()); | |||
| auto dC = reinterpret_cast<float *>(dC_tensor->data()); | |||
| auto dH = reinterpret_cast<float *>(dH_tensor->data()); | |||
| auto dY = reinterpret_cast<float *>(dy_tensor->data()); | |||
| auto dW = reinterpret_cast<float *>(dW_tensor->data()); | |||
| auto dX = reinterpret_cast<float *>(dX_tensor->data()); | |||
| auto weights = reinterpret_cast<float *>(weights_tensor->data()); | |||
| auto input = reinterpret_cast<float *>(input_tensor->data()); | |||
| auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| auto seq_stride = lstm_param_->seq_len_ * state_size; | |||
| float *hidden_state = intermediate_data; | |||
| float *cell_state = intermediate_data + seq_stride * 1; | |||
| float *input_gate = intermediate_data + seq_stride * 2; | |||
| float *output_gate = intermediate_data + seq_stride * 3; | |||
| float *forget_gate = intermediate_data + seq_stride * 4; | |||
| float *cell_gate = intermediate_data + seq_stride * 5; | |||
| ReorderLstmWeightGrad(weights_tmp_, weights, getLstmOrderIFGO(), false); | |||
| memset(dH, 0, dH_tensor->Size()); | |||
| memset(dC, 0, dC_tensor->Size()); | |||
| memset(dW_tmp_, 0, dW_tensor->Size()); // dW_tmp is summed in the loop | |||
| float *workspace_gemm = workspace_ + GetRunWorkspaceGemmOffset(lstm_param_); | |||
| // Get Tensors Data | |||
| int time_stamp_len = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| weights_ = reinterpret_cast<float *>(weights_tensor->data()); | |||
| ReorderLstmWeightGrad(weights_tmp_, weights_, getLstmOrderIFGO(), false); | |||
| dC_ = reinterpret_cast<float *>(dC_tensor->data()); | |||
| dH_ = reinterpret_cast<float *>(dH_tensor->data()); | |||
| dX_ = reinterpret_cast<float *>(dX_tensor->data()); | |||
| input_ = reinterpret_cast<float *>(input_tensor->data()); | |||
| memset(dH_, 0, dH_tensor->Size()); | |||
| memset(dC_, 0, dC_tensor->Size()); | |||
| memset(dX_, 0, dX_tensor->Size()); | |||
| memset(dW_tmp_, 0, dW_tensor->Size()); | |||
| if (lstm_param_->bidirectional_) { | |||
| // Adjust pointer to backward cell | |||
| cell_input_data_ = reinterpret_cast<float *>(cell_input_tensor->data()) + time_stamp_len; | |||
| hidden_input_data_ = reinterpret_cast<float *>(hidden_input_tensor->data()) + time_stamp_len; | |||
| intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data()) + time_stamp_len; | |||
| dC_ = reinterpret_cast<float *>(dC_tensor->data()) + time_stamp_len; | |||
| dH_ = reinterpret_cast<float *>(dH_tensor->data()) + time_stamp_len; | |||
| dY_ = reinterpret_cast<float *>(dy_tensor->data()) + lstm_param_->hidden_size_; | |||
| int w_offset = num_of_gates * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| int v_offset = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_ + | |||
| num_of_gates * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| int b_offset = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_ + | |||
| weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_ + | |||
| num_of_gates * lstm_param_->hidden_size_; | |||
| float *w = weights_tmp_ + w_offset; | |||
| float *v = weights_tmp_ + v_offset; | |||
| float *dw = dW_tmp_ + w_offset; | |||
| float *dv = dW_tmp_ + v_offset; | |||
| float *db = (lstm_param_->has_bias_) ? dW_tmp_ + b_offset : nullptr; | |||
| LstmBackpropUnidirectional(true, w, v, dw, dv, db); | |||
| } | |||
| // adjust to forward cell | |||
| cell_input_data_ = reinterpret_cast<float *>(cell_input_tensor->data()); | |||
| hidden_input_data_ = reinterpret_cast<float *>(hidden_input_tensor->data()); | |||
| intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data()); | |||
| dC_ = reinterpret_cast<float *>(dC_tensor->data()); | |||
| dH_ = reinterpret_cast<float *>(dH_tensor->data()); | |||
| dY_ = reinterpret_cast<float *>(dy_tensor->data()); | |||
| int w_offset = 0; | |||
| int v_offset = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| int b_offset = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_ + | |||
| weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| float *w = weights_tmp_ + w_offset; | |||
| float *v = weights_tmp_ + v_offset; | |||
| float *dw = dW_tmp_ + w_offset; | |||
| float *dv = dW_tmp_ + v_offset; | |||
| float *db = (lstm_param_->has_bias_) ? dW_tmp_ + b_offset : nullptr; | |||
| LstmBackpropUnidirectional(false, w, v, dw, dv, db); | |||
| // setup output tensors | |||
| dW_ = reinterpret_cast<float *>(dW_tensor->data()); | |||
| dh_out_ = reinterpret_cast<float *>(dH_out_tensor->data()); | |||
| dc_out_ = reinterpret_cast<float *>(dC_out_tensor->data()); | |||
| std::copy(&(dH_[0]), &(dH_[dH_tensor->ElementsNum()]), &(dh_out_[0])); | |||
| std::copy(&(dC_[0]), &(dC_[dC_tensor->ElementsNum()]), &(dc_out_[0])); | |||
| ReorderLstmWeightGrad(dW_, dW_tmp_, getLstmOrderIOFG(), lstm_param_->has_bias_); | |||
| FreeRunBuffer(); | |||
| return RET_OK; | |||
| } | |||
| int LSTMGradCPUKernel::LstmBackpropUnidirectional(bool is_backward, float *w, float *v, float *dw, float *dv, | |||
| float *db) { | |||
| auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; | |||
| float *hidden_state = intermediate_data_; | |||
| float *cell_state = intermediate_data_ + seq_stride * 1; | |||
| float *input_gate = intermediate_data_ + seq_stride * 2; | |||
| float *output_gate = intermediate_data_ + seq_stride * 3; | |||
| float *forget_gate = intermediate_data_ + seq_stride * 4; | |||
| float *cell_gate = intermediate_data_ + seq_stride * 5; | |||
| float *workspace_gemm = workspace_ + GetRunWorkspaceGemmOffset(lstm_param_); | |||
| int dir_mult = lstm_param_->bidirectional_ ? 2 : 1; | |||
| int prev_time_stamp_offset = (is_backward) ? 1 : -1; | |||
| int first_time_stamp = (is_backward) ? lstm_param_->seq_len_ - 1 : 0; | |||
| for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) { | |||
| int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t; | |||
| auto stride = real_t * state_size; | |||
| auto stride = real_t * lstm_param_->output_step_; | |||
| float *prev_hidden_state = (real_t > 0) ? hidden_state + (real_t - 1) * state_size : hidden_input_data; | |||
| float *prev_hidden_state = (real_t == first_time_stamp) | |||
| ? hidden_input_data_ | |||
| : hidden_state + (real_t + prev_time_stamp_offset) * lstm_param_->output_step_; | |||
| float *curr_cell_state = cell_state + stride; | |||
| float *prev_cell_state = (real_t > 0) ? cell_state + (real_t - 1) * state_size : cell_input_data; | |||
| float *prev_cell_state = (real_t == first_time_stamp) | |||
| ? cell_input_data_ | |||
| : cell_state + (real_t + prev_time_stamp_offset) * lstm_param_->output_step_; | |||
| float *curr_input_gate = input_gate + stride; | |||
| float *curr_forget_gate = forget_gate + stride; | |||
| float *curr_cell_gate = cell_gate + stride; | |||
| float *curr_output_gate = output_gate + stride; | |||
| float *curr_input = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_; | |||
| float *curr_dx = dX + real_t * lstm_param_->batch_ * lstm_param_->input_size_; | |||
| float *curr_dy = dY + real_t * state_size; | |||
| float *curr_input = input_ + real_t * lstm_param_->batch_ * lstm_param_->input_size_; | |||
| float *curr_dx = dX_ + real_t * lstm_param_->batch_ * lstm_param_->input_size_; | |||
| int seq_offset = real_t * lstm_param_->output_step_; | |||
| for (int b = 0; b < lstm_param_->batch_; b++) { | |||
| int batch_offset = b * dir_mult * lstm_param_->hidden_size_; | |||
| float *dy = dY_ + seq_offset + batch_offset; | |||
| memcpy(curr_dy_ + b * lstm_param_->hidden_size_, dy, lstm_param_->hidden_size_ * sizeof(float)); | |||
| } | |||
| float *dA = nullptr; | |||
| LstmGradDoInputStep(curr_output_gate, curr_cell_state, prev_cell_state, curr_cell_gate, curr_input_gate, | |||
| curr_forget_gate, curr_dy, dC, dH, &dA, curr_dx, weights_tmp_, workspace_, lstm_param_); | |||
| LstmGradDoWeightStep(curr_input, prev_hidden_state, dA, dW_tmp_, workspace_gemm, lstm_param_); | |||
| curr_forget_gate, curr_dy_, dC_, dH_, &dA, curr_dx, w, v, workspace_, lstm_param_); | |||
| LstmGradDoWeightStep(curr_input, prev_hidden_state, dA, dw, dv, db, workspace_gemm, lstm_param_); | |||
| } | |||
| std::copy(&(dH[0]), &(dH[state_size]), &(dh_out[0])); | |||
| std::copy(&(dC[0]), &(dC[state_size]), &(dc_out[0])); | |||
| ReorderLstmWeightGrad(dW, dW_tmp_, getLstmOrderIOFG(), lstm_param_->has_bias_); | |||
| return RET_OK; | |||
| } | |||
| @@ -158,11 +208,6 @@ int LSTMGradCPUKernel::InitParam() { | |||
| lstm_param_->batch_ = in_shape.at(SECOND_INPUT); | |||
| lstm_param_->input_size_ = in_shape.at(THIRD_INPUT); | |||
| auto dy = in_tensors_.at(dy_index); | |||
| MS_ASSERT(dy != nullptr); | |||
| std::vector<int> dy_shape = dy->shape(); | |||
| lstm_param_->hidden_size_ = dy_shape.at(THIRD_INPUT); | |||
| int dir_multiplier = lstm_param_->bidirectional_ ? 2 : 1; | |||
| lstm_param_->output_step_ = dir_multiplier * lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| weight_batch_ = dir_multiplier * num_of_gates; | |||
| @@ -198,7 +243,7 @@ int LSTMGradCPUKernel::InitParam() { | |||
| } | |||
| int LSTMGradCPUKernel::MallocRunBuffer() { | |||
| int workspace_size = GetRunWorkspaceSize(lstm_param_); | |||
| if ((workspace_size == 0) || (workspace_size > LSTMGRAD_MAX_WORKSPACE_SIZE)) { | |||
| if (workspace_size == 0) { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc run workspace 0 error."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -210,7 +255,7 @@ int LSTMGradCPUKernel::MallocRunBuffer() { | |||
| auto dW_tensor = out_tensors_.at(dW_out_index); | |||
| MS_ASSERT(dW_tensor != nullptr); | |||
| auto dW_size = dW_tensor->Size(); | |||
| if ((dW_size == 0) || (dW_size > LSTMGRAD_MAX_WEIGHTS_SIZE)) { | |||
| if (dW_size == 0) { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc run dW_tmp size error."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -226,7 +271,12 @@ int LSTMGradCPUKernel::MallocRunBuffer() { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc run weights_tmp_ alloc error."; | |||
| return RET_ERROR; | |||
| } | |||
| int curr_dy_size = lstm_param_->hidden_size_ * lstm_param_->batch_; | |||
| curr_dy_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(curr_dy_size * sizeof(float))); | |||
| if (curr_dy_ == nullptr) { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc run curr_dy_ alloc error."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -243,6 +293,10 @@ void LSTMGradCPUKernel::FreeRunBuffer() { | |||
| ms_context_->allocator->Free(weights_tmp_); | |||
| weights_tmp_ = nullptr; | |||
| } | |||
| if (curr_dy_ != nullptr) { | |||
| ms_context_->allocator->Free(curr_dy_); | |||
| curr_dy_ = nullptr; | |||
| } | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTMGrad, LiteKernelCreator<LSTMGradCPUKernel>) | |||
| @@ -23,8 +23,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr int LSTMGRAD_MAX_WORKSPACE_SIZE = 100000; | |||
| constexpr int LSTMGRAD_MAX_WEIGHTS_SIZE = 100000; | |||
| class LSTMGradCPUKernel : public InnerKernel { | |||
| public: | |||
| explicit LSTMGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| @@ -39,7 +37,7 @@ class LSTMGradCPUKernel : public InnerKernel { | |||
| int DoGrad(int thread_id); | |||
| private: | |||
| int LstmBackpropUnidirectional(bool is_backward); | |||
| int LstmBackpropUnidirectional(bool is_backward, float *w, float *v, float *dw, float *dv, float *db); | |||
| int InitParam(); | |||
| int MallocRunBuffer(); | |||
| @@ -73,6 +71,20 @@ class LSTMGradCPUKernel : public InnerKernel { | |||
| int input_thread_count_ = 0; | |||
| int input_thread_stride_ = 0; | |||
| float *cell_input_data_ = nullptr; | |||
| float *hidden_input_data_ = nullptr; | |||
| float *dh_out_ = nullptr; | |||
| float *dc_out_ = nullptr; | |||
| float *intermediate_data_ = nullptr; | |||
| float *dC_ = nullptr; | |||
| float *dH_ = nullptr; | |||
| float *dY_ = nullptr; | |||
| float *dW_ = nullptr; | |||
| float *dX_ = nullptr; | |||
| float *weights_ = nullptr; | |||
| float *input_ = nullptr; | |||
| float *curr_dy_ = nullptr; | |||
| LstmGradParameter *lstm_param_ = nullptr; | |||
| }; | |||
| } // namespace kernel | |||
| @@ -47,59 +47,115 @@ int LSTMGradWeightCPUKernel::Run() { | |||
| return RET_ERROR; | |||
| } | |||
| auto output = out_tensors_.at(0); | |||
| auto output_ptr = reinterpret_cast<float *>(output->data()); | |||
| CHECK_NULL_RETURN(output_ptr); | |||
| LstmBackpropUnidirectional(output_ptr, false); | |||
| FreeRunBuffer(); | |||
| return RET_OK; | |||
| } | |||
| int LSTMGradWeightCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backward) { | |||
| auto dW_tensor = out_tensors_.at(dW_out_index); | |||
| MS_ASSERT(dW_tensor != nullptr); | |||
| auto intermediate_tensor = in_tensors_.at(intermediate_data_index); | |||
| MS_ASSERT(intermediate_tensor != nullptr); | |||
| auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; | |||
| auto input_tensor = in_tensors_.at(input_index); | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| auto hidden_input_tensor = in_tensors_.at(hidden_input_index); | |||
| MS_ASSERT(hidden_input_tensor != nullptr); | |||
| auto intermediate_tensor = in_tensors_.at(intermediate_data_index); | |||
| MS_ASSERT(intermediate_tensor != nullptr); | |||
| auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data()); | |||
| auto input = reinterpret_cast<float *>(input_tensor->data()); | |||
| auto dW = reinterpret_cast<float *>(dW_tensor->data()); | |||
| auto hidden_input_data = reinterpret_cast<float *>(hidden_input_tensor->data()); | |||
| // Get output tensors | |||
| auto dW_tensor = out_tensors_.at(dW_out_index); | |||
| MS_ASSERT(dW_tensor != nullptr); | |||
| // Get Tensors Data | |||
| int time_stamp_len = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| input_ = reinterpret_cast<float *>(input_tensor->data()); | |||
| memset(dW_tmp_, 0, dW_tensor->Size()); | |||
| int w_size = lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| int h_size = lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| int b_size = lstm_param_->hidden_size_; | |||
| if (lstm_param_->bidirectional_) { | |||
| // Adjust pointer to backward cell | |||
| hidden_input_data_ = reinterpret_cast<float *>(hidden_input_tensor->data()) + time_stamp_len; | |||
| intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data()); | |||
| dA_ = intermediate_data_ + seq_stride * 1 + lstm_param_->seq_len_ * num_of_gates * time_stamp_len; | |||
| intermediate_data_ += time_stamp_len; | |||
| int w_offset = num_of_gates * (w_size + h_size); | |||
| int v_offset = weight_batch_ * w_size + num_of_gates * h_size; | |||
| int b_offset = weight_batch_ * (w_size + h_size) + num_of_gates * b_size; | |||
| float *dw = dW_tmp_ + w_offset; | |||
| float *dv = dW_tmp_ + v_offset; | |||
| float *db = nullptr; | |||
| if (lstm_param_->has_bias_) { | |||
| db = dW_tmp_ + b_offset; | |||
| } | |||
| LstmBackpropUnidirectional(true, dw, dv, db); | |||
| } | |||
| // adjust to forward cell | |||
| hidden_input_data_ = reinterpret_cast<float *>(hidden_input_tensor->data()); | |||
| intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data()); | |||
| dA_ = intermediate_data_ + seq_stride * 1; | |||
| int w_offset = 0; | |||
| int v_offset = num_of_gates * w_size; | |||
| int b_offset = weight_batch_ * (w_size + h_size); | |||
| float *dw = dW_tmp_ + w_offset; | |||
| float *dv = dW_tmp_ + v_offset; | |||
| float *db = nullptr; | |||
| if (lstm_param_->has_bias_) { | |||
| db = dW_tmp_ + b_offset; | |||
| } | |||
| LstmBackpropUnidirectional(false, dw, dv, db); | |||
| auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| auto seq_stride = lstm_param_->seq_len_ * state_size; | |||
| float *hidden_state = intermediate_data; | |||
| float *dA = intermediate_data + seq_stride * 1; // intermediate tensor used to transfer dA data from GradData kernel | |||
| // setup output tensors | |||
| dW_ = reinterpret_cast<float *>(dW_tensor->data()); | |||
| ReorderLstmWeightGrad(dW_, dW_tmp_, lstm_param_); | |||
| FreeRunBuffer(); | |||
| return RET_OK; | |||
| } | |||
| int LSTMGradWeightCPUKernel::LstmBackpropUnidirectional(bool is_backward, float *dw, float *dv, float *db) { | |||
| float *hidden_state = intermediate_data_; | |||
| int state_len = lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| int prev_time_stamp_offset = (is_backward) ? 1 : -1; | |||
| int first_time_stamp = (is_backward) ? lstm_param_->seq_len_ - 1 : 0; | |||
| memset(dW_tmp_, 0, dW_tensor->Size()); // dW_tmp is summed in the loop | |||
| for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) { | |||
| int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t; | |||
| float *curr_input = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_; | |||
| float *prev_hidden_state = (real_t > 0) ? hidden_state + (real_t - 1) * state_size : hidden_input_data; | |||
| float *curr_da = dA + real_t * num_of_gates * lstm_param_->output_step_; | |||
| LstmGradDoWeightStep(curr_input, prev_hidden_state, curr_da, dW_tmp_, workspace_, lstm_param_); | |||
| float *prev_hidden_state = (real_t == first_time_stamp) | |||
| ? hidden_input_data_ | |||
| : hidden_state + (real_t + prev_time_stamp_offset) * lstm_param_->output_step_; | |||
| float *curr_input = input_ + real_t * lstm_param_->batch_ * lstm_param_->input_size_; | |||
| float *curr_da = dA_ + real_t * num_of_gates * state_len; | |||
| LstmGradDoWeightStep(curr_input, prev_hidden_state, curr_da, dw, dv, db, workspace_, lstm_param_); | |||
| } | |||
| ReorderLstmWeightGrad(dW, dW_tmp_, lstm_param_->has_bias_); | |||
| return RET_OK; | |||
| } | |||
| void LSTMGradWeightCPUKernel::ReorderLstmWeightGrad(float *dst, float *src, bool has_bias) { | |||
| ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIOFG()); | |||
| src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| if (has_bias) { | |||
| ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| // update second bias term (only if GradData and GradWeight are separate) | |||
| dst += weight_batch_ * lstm_param_->hidden_size_; | |||
| ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| void LSTMGradWeightCPUKernel::ReorderLstmWeightGrad(float *dst, float *src, LstmGradParameter *param) { | |||
| int uni_batch = param->bidirectional_ ? weight_batch_ / 2 : weight_batch_; | |||
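| // dst layout: 4xWi, 4xWh, [4xWi_r, 4xWh_r], 4xBi, 4xBh, [4xBi_r, 4xBh_r]; "_r" groups only when bidirectional, bias groups only when has_bias_ | |||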
| ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIOFG()); | |||
| src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| if (param->bidirectional_) { | |||
| ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIOFG()); | |||
| src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_; | |||
| ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; | |||
| } | |||
| if (param->has_bias_) { | |||
| ReorderLstmWeights(dst, src, uni_batch, 1, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| dst += uni_batch * lstm_param_->hidden_size_; | |||
| ReorderLstmWeights(dst, src, uni_batch, 1, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| dst += uni_batch * lstm_param_->hidden_size_; | |||
| src += uni_batch * lstm_param_->hidden_size_; | |||
| if (param->bidirectional_) { | |||
| ReorderLstmWeights(dst, src, uni_batch, 1, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| dst += uni_batch * lstm_param_->hidden_size_; | |||
| ReorderLstmWeights(dst, src, uni_batch, 1, lstm_param_->hidden_size_, getLstmOrderIOFG()); | |||
| dst += uni_batch * lstm_param_->hidden_size_; | |||
| src += uni_batch * lstm_param_->hidden_size_; | |||
| } | |||
| } | |||
| } | |||
| @@ -112,10 +168,6 @@ int LSTMGradWeightCPUKernel::InitParam() { | |||
| lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT); | |||
| lstm_param_->batch_ = in_shape.at(SECOND_INPUT); | |||
| lstm_param_->input_size_ = in_shape.at(THIRD_INPUT); | |||
| auto y = in_tensors_.at(y_index); | |||
| MS_ASSERT(y != nullptr); | |||
| std::vector<int> y_shape = y->shape(); | |||
| lstm_param_->hidden_size_ = y_shape.at(THIRD_INPUT); | |||
| int dir_multiplier = lstm_param_->bidirectional_ ? 2 : 1; | |||
| lstm_param_->output_step_ = dir_multiplier * lstm_param_->batch_ * lstm_param_->hidden_size_; | |||
| @@ -153,7 +205,7 @@ int LSTMGradWeightCPUKernel::InitParam() { | |||
| int LSTMGradWeightCPUKernel::MallocRunBuffer() { | |||
| int workspace_size = GetRunWorkspaceSize(lstm_param_); | |||
| if ((workspace_size == 0) || (workspace_size > LSTMGRADWEIGHT_MAX_WORKSPACE_SIZE)) { | |||
| if (workspace_size == 0) { | |||
| MS_LOG(ERROR) << "LstmGradWeightCPUKernel malloc run workspace 0 error."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -166,7 +218,7 @@ int LSTMGradWeightCPUKernel::MallocRunBuffer() { | |||
| auto dW_tensor = out_tensors_.at(dW_out_index); | |||
| MS_ASSERT(dW_tensor != nullptr); | |||
| auto dW_size = dW_tensor->Size(); | |||
| if ((dW_size == 0) || (dW_size > LSTMGRADWEIGHT_MAX_WEIGHTS_SIZE)) { | |||
| if (dW_size == 0) { | |||
| MS_LOG(ERROR) << "LstmGradWeightCPUKernel malloc run dW_tmp size error."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -23,8 +23,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr int LSTMGRADWEIGHT_MAX_WORKSPACE_SIZE = 100000; | |||
| constexpr int LSTMGRADWEIGHT_MAX_WEIGHTS_SIZE = 100000; | |||
| class LSTMGradWeightCPUKernel : public InnerKernel { | |||
| public: | |||
| explicit LSTMGradWeightCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| @@ -39,12 +37,12 @@ class LSTMGradWeightCPUKernel : public InnerKernel { | |||
| int DoGrad(int thread_id); | |||
| private: | |||
| int LstmBackpropUnidirectional(float *output, bool is_backward); | |||
| int LstmBackpropUnidirectional(bool is_backward, float *dw, float *dv, float *db); | |||
| int InitParam(); | |||
| int MallocRunBuffer(); | |||
| void FreeRunBuffer(); | |||
| void ReorderLstmWeightGrad(float *dst, float *src, bool has_bias); | |||
| void ReorderLstmWeightGrad(float *dst, float *src, LstmGradParameter *param); | |||
| static const int input_index = 0; | |||
| static const int hidden_input_index = 1; | |||
| @@ -65,7 +63,11 @@ class LSTMGradWeightCPUKernel : public InnerKernel { | |||
| bool state_is_vec_ = false; | |||
| int input_thread_count_ = 0; | |||
| int input_thread_stride_ = 0; | |||
| float *input_ = nullptr; | |||
| float *hidden_input_data_ = nullptr; | |||
| float *intermediate_data_ = nullptr; | |||
| float *dW_ = nullptr; | |||
| float *dA_ = nullptr; | |||
| LstmGradParameter *lstm_param_ = nullptr; | |||
| }; | |||
| } // namespace kernel | |||
| @@ -526,6 +526,7 @@ OpParameter *PopulateLstmGradParameter(const void *prim) { | |||
| param->zoneout_hidden_ = value->zoneout_hidden(); | |||
| param->input_size_ = value->input_size(); | |||
| param->has_bias_ = value->has_bias(); | |||
| param->hidden_size_ = value->hidden_size(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| @@ -552,6 +553,7 @@ OpParameter *PopulateLstmGradDataParameter(const void *prim) { | |||
| param->zoneout_hidden_ = value->zoneout_hidden(); | |||
| param->input_size_ = value->input_size(); | |||
| param->has_bias_ = value->has_bias(); | |||
| param->hidden_size_ = value->hidden_size(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| @@ -573,7 +575,7 @@ OpParameter *PopulateLstmGradWeightParameter(const void *prim) { | |||
| param->op_parameter_.type_ = primitive->value_type(); | |||
| param->input_size_ = value->input_size(); | |||
| param->hidden_size_ = value->hidden_size(); // output_size | |||
| param->hidden_size_ = value->hidden_size(); | |||
| param->bidirectional_ = value->bidirectional(); | |||
| param->zoneout_cell_ = value->zoneout_cell(); | |||
| param->zoneout_hidden_ = value->zoneout_hidden(); | |||