diff --git a/mindspore/lite/nnacl/fp32/gru_fp32.c b/mindspore/lite/nnacl/fp32/gru_fp32.c
index 164e9cc195..8e147e495f 100644
--- a/mindspore/lite/nnacl/fp32/gru_fp32.c
+++ b/mindspore/lite/nnacl/fp32/gru_fp32.c
@@ -18,115 +18,126 @@
 #include "nnacl/fp32/lstm_fp32.h"
 #include "nnacl/fp32/activation_fp32.h"
 #include "nnacl/fp32/arithmetic_fp32.h"
-
-void InitGruGate(float *gate_buffer, const float *bias, const GruParameter *gru_parm) {
-  int gate_offest = 0;
-  for (int l = 0; l < 3; l++) {
-    int batch_offest = gate_offest;
-    int bias_offest = l * gru_parm->hidden_size_;
-    for (int b = 0; b < gru_parm->batch_; b++) {
-      memcpy(gate_buffer + batch_offest, bias + bias_offest, gru_parm->hidden_size_ * sizeof(float));
-      batch_offest += gru_parm->hidden_size_;
-    }
-    gate_offest += gru_parm->batch_ * gru_parm->hidden_size_;
+#include "nnacl/fp32/matmul_fp32.h"
+
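+// One matmul per GRU gate (update, reset, hidden candidate): gate i reads its weight block at
+// deep * col * i and its bias block at col_align * i, and writes gate_buffer segment i.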
+void UpdateGruInputGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row,
+                        int deep, int col, int col_align, bool is_vec) {
+  for (int i = 0; i < 3; i++) {
+    const float *weight_i = weight + deep * col * i;
+    const float *bias_i = bias + col_align * i;
+    float *gate = gate_buffer + row * col * i;
+    LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, is_vec);
   }
 }

-void GruStepUnit(float *output, const float *input, const float *input_reset_weight, const float *input_update_weight,
-                 const float *input_hidden_weight, const float *state_reset_weight, const float *state_update_weight,
-                 const float *state_hidden_weight, const float *bias, float *hidden_state, float *gate_buffer,
-                 const GruParameter *gru_parm) {
-  InitGruGate(gate_buffer, bias, gru_parm);
-
-  float *update_gate = gate_buffer;
-  float *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
-  float *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;
-
+void GruStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
+                 const float *bias, float *hidden_state, float *gate_buffer, float *matmul_buffer[2],
+                 const GruParameter *gru_param) {
+  bool is_vec = gru_param->batch_ == 1;
   // input * weight
-  MatMulAcc(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
-  MatMulAcc(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
-  MatMulAcc(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
+  if (is_vec) {
+    UpdateGruInputGate(gate_buffer, input, input_weight, bias, gru_param->batch_, gru_param->input_size_,
+                       gru_param->hidden_size_, gru_param->col_align_, is_vec);
+  } else {
+    // pack input for matmul
+    PackLstmInput(matmul_buffer[0], input, gru_param->batch_, gru_param->input_size_);
+    UpdateGruInputGate(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_,
+                       gru_param->hidden_size_, gru_param->col_align_, is_vec);
+  }
+
+  const float *state_update_weight = state_weight;
+  const float *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_;
+  const float *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2;
+  float *state_update_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 3;
+  float *state_reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 4;
+  float *state_hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 5;
+  const float *state_update_bias = bias + gru_param->hidden_size_ * 3;
+  const float *state_reset_bias = bias + gru_param->hidden_size_ * 4;
+  const float *state_hidden_bias = bias + gru_param->hidden_size_ * 5;

   // state * weight
-  MatMulAcc(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
-            gru_parm->hidden_size_);
-  MatMulAcc(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
-            gru_parm->hidden_size_);
+  if (is_vec) {
+    LstmMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
+               gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+    LstmMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
+               gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  } else {
+    PackLstmInput(matmul_buffer[1], hidden_state, gru_param->batch_, gru_param->hidden_size_);
+    LstmMatMul(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_,
+               gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+    LstmMatMul(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_,
+               gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  }
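+  // gate_buffer now holds six batch_ * hidden_size_ segments: 0-2 are the input matmuls and
+  // 3-5 the state matmuls (update, reset, hidden in that order). Update and reset lie back to
+  // back, so one ElementAdd over 2 * batch_ * hidden_size_ fuses both gates; the hidden
+  // candidate is combined later, after the reset gate has been applied.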
+  ElementAdd(gate_buffer, state_update_gate, gate_buffer, gru_param->batch_ * gru_param->hidden_size_ * 2);
+  float *update_gate = gate_buffer;
+  float *reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_;
+  float *hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 2;

   // update reset_gate
-  Sigmoid(reset_gate, gru_parm->batch_ * gru_parm->hidden_size_, reset_gate);
-
+  Sigmoid(reset_gate, gru_param->batch_ * gru_param->hidden_size_, reset_gate);
   // update update_gate
-  Sigmoid(update_gate, gru_parm->batch_ * gru_parm->hidden_size_, update_gate);
-
-  ElementMul(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
-  MatMulAcc(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
-            gru_parm->hidden_size_);
+  Sigmoid(update_gate, gru_param->batch_ * gru_param->hidden_size_, update_gate);
+
+  ElementMul(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_);
+  if (is_vec) {
+    LstmMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
+               gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  } else {
+    PackLstmInput(matmul_buffer[1], reset_gate, gru_param->batch_, gru_param->hidden_size_);
+    LstmMatMul(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_,
+               gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  }
+  ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_);

-  Tanh(hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_, hidden_buffer);
+  Tanh(hidden_buffer, gru_param->batch_ * gru_param->hidden_size_, hidden_buffer);

-  ElementMul(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
+  ElementMul(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_);

   ArithmeticParameter parameter;
   parameter.in_elements_num0_ = 1;
-  parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
+  parameter.in_elements_num1_ = gru_param->batch_ * gru_param->hidden_size_;
   const float one = 1.0f;
-  ElementOptSub(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter);
+  ElementOptSub(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, &parameter);

-  ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
+  ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_);

-  memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float));
+  memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float));
 }

 void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
-         float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm) {
+         float *hidden_state, float *gate_buffer, float *matmul_buffer[2], int check_seq_len,
+         const GruParameter *gru_param) {
   // forward
-  const float *input_update_weight = weight_g;
-  const float *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
-  const float *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;
-
-  const float *state_update_weight = weight_r;
-  const float *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
-  const float *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;
-
   for (int t = 0; t < check_seq_len; t++) {
-    const float *input_ptr = input + t * gru_parm->input_step_;
-    float *output_ptr = output + t * gru_parm->output_step_;
-    GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, state_reset_weight,
-                state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer, gru_parm);
+    const float *input_ptr = input + t * gru_param->input_step_;
+    float *output_ptr = output + t * gru_param->output_step_;
+    GruStepUnit(output_ptr, input_ptr, weight_g, weight_r, bias, hidden_state, gate_buffer, matmul_buffer, gru_param);
   }
   // zero out extra fw outputs
-  for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
-    float *output_ptr = output + t * gru_parm->output_step_;
-    for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
+  for (int t = check_seq_len; t < gru_param->seq_len_; t++) {
+    float *output_ptr = output + t * gru_param->output_step_;
+    for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) {
       output_ptr[i] = 0.0f;
     }
   }

   // backward
-  if (gru_parm->bidirectional_) {
-    input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
-    input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
-    input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;
-
-    state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
-    state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
-    state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;
-
-    float *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
-    const float *backward_bias = bias + 3 * gru_parm->hidden_size_;
-    float *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
+  if (gru_param->bidirectional_) {
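+    // the backward weights start after the three forward gate blocks; each block spans
+    // col_align_ packed columns, hence the col_align_ stride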
+    const float *backward_weight_g = weight_g + 3 * gru_param->col_align_ * gru_param->input_size_;
+    const float *backward_weight_r = weight_r + 3 * gru_param->col_align_ * gru_param->hidden_size_;
+    const float *backward_bias = bias + 6 * gru_param->hidden_size_;
+    float *backward_output = output + gru_param->batch_ * gru_param->hidden_size_;
+    float *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_;

     for (int t = check_seq_len - 1; t >= 0; t--) {
-      const float *input_ptr = input + t * gru_parm->input_step_;
-      float *output_ptr = backward_output + t * gru_parm->output_step_;
-      GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
-                  state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, backward_hidden_state,
-                  gate_buffer, gru_parm);
+      const float *input_ptr = input + t * gru_param->input_step_;
+      float *output_ptr = backward_output + t * gru_param->output_step_;
+      GruStepUnit(output_ptr, input_ptr, backward_weight_g, backward_weight_r, backward_bias, backward_hidden_state,
+                  gate_buffer, matmul_buffer, gru_param);
     }
     // zero out extra bw outputs
-    for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
-      float *output_ptr = backward_output + t * gru_parm->output_step_;
-      for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
+    for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) {
+      float *output_ptr = backward_output + t * gru_param->output_step_;
+      for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) {
         output_ptr[i] = 0.0f;
       }
     }
diff --git a/mindspore/lite/nnacl/fp32/gru_fp32.h b/mindspore/lite/nnacl/fp32/gru_fp32.h
index a9fc4d2555..69ddd23bf0 100644
--- a/mindspore/lite/nnacl/fp32/gru_fp32.h
+++ b/mindspore/lite/nnacl/fp32/gru_fp32.h
@@ -21,7 +21,8 @@
 extern "C" {
 #endif
 void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
-         float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm);
+         float *hidden_state, float *gate_buffer, float *matmul_buffer[2], int check_seq_len,
+         const GruParameter *gru_parm);
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.c b/mindspore/lite/nnacl/fp32/lstm_fp32.c
index cd1f12007c..13c416f59e 100644
--- a/mindspore/lite/nnacl/fp32/lstm_fp32.c
+++ b/mindspore/lite/nnacl/fp32/lstm_fp32.c
@@ -21,6 +21,30 @@
 #include "nnacl/fp32/arithmetic_fp32.h"
 #include "nnacl/fp32/matmul_fp32.h"

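+// Repack one [col, deep] row-major weight block per gate into the column-tiled layout that
+// MatMulOpt expects for its right-hand matrix; the tile width (AVX 16, ARM32 4, otherwise 8)
+// must match the col_tile_ the kernels select.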
+void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align) {
+  for (int i = 0; i < batch; i++) {
+    const float *src_batch = src + i * col * deep;
+    float *dst_batch = dst + i * col_align * deep;
+#ifdef ENABLE_AVX
+    RowMajor2Col16Major(src_batch, dst_batch, col, deep);
+#elif defined(ENABLE_ARM32)
+    RowMajor2Col4Major(src_batch, dst_batch, col, deep);
+#else
+    RowMajor2Col8Major(src_batch, dst_batch, col, deep);
+#endif
+  }
+}
+
+void PackLstmInput(float *dst, const float *src, int row, int deep) {
+#ifdef ENABLE_AVX
+  RowMajor2Col6Major(src, dst, row, deep);
+#elif defined(ENABLE_SSE)
+  RowMajor2Col4Major(src, dst, row, deep);
+#else
+  RowMajor2Col12Major(src, dst, row, deep);
+#endif
+}
+
 // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
 void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) {
   for (int r = 0; r < rows; r++) {
@@ -52,6 +76,15 @@ void MatMulAcc(float *output, const float *input, const float *weight, int rows,
   }
 }

+void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
+  if (is_vec) {
+    memcpy(c, bias, col * sizeof(float));
+    MatMulAcc(c, a, b, row, col, deep);
+  } else {
+    MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
+  }
+}
+
 void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size) {
   int index = 0;
 #ifdef ENABLE_ARM
@@ -121,74 +154,42 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd
   }
 }

-void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
-  if (is_vec) {
-    memcpy(c, bias, col * sizeof(float));
-    MatMulAcc(c, a, b, row, col, deep);
-  } else {
-    MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
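+// The weight blocks are ordered input, output, forget, cell, and gate i's weight, packed bias,
+// and output segment all live at block offset i, so the four explicit per-gate calls collapse
+// into a single loop.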
+void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
+                    int col, int col_align, bool is_vec) {
+  for (int i = 0; i < 4; i++) {
+    const float *weight_i = weight + deep * col * i;
+    const float *bias_i = bias + col_align * i;
+    float *gate = gate_buffer + row * col * i;
+    LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, is_vec);
   }
 }

-void PackLstmInput(float *dst, const float *src, int row, int deep) {
-#ifdef ENABLE_AVX
-  RowMajor2Col6Major(src, dst, row, deep);
-#elif defined(ENABLE_SSE)
-  RowMajor2Col4Major(src, dst, row, deep);
-#else
-  RowMajor2Col12Major(src, dst, row, deep);
-#endif
-}
-
-void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
-                int col, int col_align, bool is_vec) {
-  const float *input_weight = weight;
-  const float *forget_weight = weight + deep * col * 2;
-  const float *cell_weight = weight + deep * col * 3;
-  const float *output_weight = weight + deep * col;
-
-  const float *input_bias = bias;
-  const float *forget_bias = bias + col_align * 2;
-  const float *cell_bias = bias + col_align * 3;
-  const float *output_bias = bias + col_align;
-
-  float *input_gate = gate_buffer;
-  float *forget_gate = gate_buffer + row * col * 2;
-  float *cell_gate = gate_buffer + row * col * 3;
-  float *output_gate = gate_buffer + row * col;
-
-  LstmMatmul(input_gate, input, input_weight, input_bias, row, deep, col, is_vec);
-  LstmMatmul(forget_gate, input, forget_weight, forget_bias, row, deep, col, is_vec);
-  LstmMatmul(cell_gate, input, cell_weight, cell_bias, row, deep, col, is_vec);
-  LstmMatmul(output_gate, input, output_weight, output_bias, row, deep, col, is_vec);
-}
-
 void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
                   const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
                   float *matmul_buffer[2], const LstmParameter *lstm_param) {
   bool is_vec = lstm_param->batch_ == 1;
   // input * weight
   if (is_vec) {
-    UpdateGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
-               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+    UpdateLstmGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
+                   lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
   } else {
     // pack input for matmul
     PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_);
-    UpdateGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
-               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+    UpdateLstmGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
+                   lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
   }

   // state * weight
   float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4;
   const float *state_bias = bias + lstm_param->col_align_ * 4;
   if (is_vec) {
-    UpdateGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
-               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+    UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
+                   lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
   } else {
     // pack state for matmul
     PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_);
-    UpdateGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
-               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+    UpdateLstmGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
+                   lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
   }
   ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_);
diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.h b/mindspore/lite/nnacl/fp32/lstm_fp32.h
index 709a62a6fa..5e2d3d6176 100644
--- a/mindspore/lite/nnacl/fp32/lstm_fp32.h
+++ b/mindspore/lite/nnacl/fp32/lstm_fp32.h
@@ -21,7 +21,11 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size);
+void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align);
+
+void PackLstmInput(float *dst, const float *src, int row, int deep);
+
+void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec);

 void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size);

diff --git a/mindspore/lite/nnacl/gru_parameter.h b/mindspore/lite/nnacl/gru_parameter.h
index cbd85e1c3f..29ebdbdc12 100644
--- a/mindspore/lite/nnacl/gru_parameter.h
+++ b/mindspore/lite/nnacl/gru_parameter.h
@@ -30,6 +30,8 @@ typedef struct GruParameter {
   int input_step_;
   int output_step_;
   bool bidirectional_;
+  int col_align_;
+  int row_align_;
 } GruParameter;

 #endif  // MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc
index b623275215..c99a62de4f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc
@@ -19,6 +19,7 @@
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
 #include "nnacl/fp32/gru_fp32.h"
+#include "nnacl/fp32/lstm_fp32.h"

 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -28,82 +29,109 @@ using mindspore::schema::PrimitiveType_Gru;

 namespace mindspore::kernel {
 void GruCPUKernel::FreeTmpBuffer() {
-  if (gate_buffer_ != nullptr) {
-    free(gate_buffer_);
-    gate_buffer_ = nullptr;
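+  // packed copies exist only when batch_ > 1; in the is_vec_ case these pointers alias the
+  // input tensors' own buffers and must not be freed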
+  if (!is_vec_) {
+    if (weight_g_ptr_ != nullptr) {
+      free(weight_g_ptr_);
+      weight_g_ptr_ = nullptr;
+    }
+    if (weight_r_ptr_ != nullptr) {
+      free(weight_r_ptr_);
+      weight_r_ptr_ = nullptr;
+    }
+    if (bias_ptr_ != nullptr) {
+      free(bias_ptr_);
+      bias_ptr_ = nullptr;
+    }
   }
-  if (bias_ptr_ != nullptr) {
-    free(bias_ptr_);
-    bias_ptr_ = nullptr;
+}
+
+void GruCPUKernel::FreeRunBuffer() {
+  context_->allocator->Free(gate_buffer_);
+  if (!is_vec_) {
+    for (int i = 0; i < 2; i++) {
+      context_->allocator->Free(matmul_buffer_[i]);
+    }
   }
-  weight_g_ptr_ = nullptr;
-  weight_r_ptr_ = nullptr;
 }

 int GruCPUKernel::InitParam() {
   auto input = in_tensors_.front();
   MS_ASSERT(input != nullptr);
   std::vector<int> in_shape = input->shape();
-  gru_parm_->seq_len_ = in_shape.at(0);
-  gru_parm_->batch_ = in_shape.at(1);
-  gru_parm_->input_size_ = in_shape.at(2);
+  gru_param_->seq_len_ = in_shape.at(0);
+  gru_param_->batch_ = in_shape.at(1);
+  gru_param_->input_size_ = in_shape.at(2);

   auto weight_g = in_tensors_.at(1);
   MS_ASSERT(weight_g != nullptr);
   std::vector<int> w_shape = weight_g->shape();
-  gru_parm_->hidden_size_ = w_shape.at(1) / 3;
-
-  gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
-  gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
-                                                      : gru_parm_->batch_ * gru_parm_->hidden_size_;
-  return RET_OK;
-}
-
-int GruCPUKernel::InitBuffer() {
-  gate_buffer_ = reinterpret_cast<float *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float)));
-  if (gate_buffer_ == nullptr) {
-    MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
-    return RET_ERROR;
-  }
+  gru_param_->hidden_size_ = w_shape.at(1) / 3;
+
+  gru_param_->input_step_ = gru_param_->batch_ * gru_param_->input_size_;
+  gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_
+                                                        : gru_param_->batch_ * gru_param_->hidden_size_;
+
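+  // pick the matmul tiles for this build target; row_tile_ must match PackLstmInput and
+  // col_tile_ must match PackLstmWeight. batch_ == 1 skips packing, so no alignment is needed.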
+#ifdef ENABLE_AVX
+  row_tile_ = C6NUM;
+  col_tile_ = C16NUM;
+#elif defined(ENABLE_ARM32)
+  row_tile_ = C12NUM;
+  col_tile_ = C4NUM;
+#elif defined(ENABLE_SSE)
+  row_tile_ = C4NUM;
+  col_tile_ = C8NUM;
+#else
+  row_tile_ = C12NUM;
+  col_tile_ = C8NUM;
+#endif
+  is_vec_ = gru_param_->batch_ == 1;
+  gru_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(gru_param_->batch_, row_tile_);
+  gru_param_->col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, col_tile_);
   return RET_OK;
 }

 int GruCPUKernel::InitWeightBias() {
-  auto weight_gate = in_tensors_.at(1);
-  MS_ASSERT(weight_gate != nullptr);
-  weight_g_ptr_ = reinterpret_cast<float *>(malloc(weight_gate->ElementsNum() * sizeof(float)));
-  if (weight_g_ptr_ == nullptr) {
-    MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error.";
-    return RET_ERROR;
-  }
-  memcpy(weight_g_ptr_, weight_gate->data_c(), weight_gate->ElementsNum() * sizeof(float));
-
-  auto weight_recu = in_tensors_.at(2);
-  MS_ASSERT(weight_recu != nullptr);
-  weight_r_ptr_ = reinterpret_cast<float *>(malloc(weight_recu->ElementsNum() * sizeof(float)));
-  if (weight_r_ptr_ == nullptr) {
-    MS_LOG(ERROR) << "GruCPUKernel malloc weight_r_ptr_ error.";
-    return RET_ERROR;
-  }
-  memcpy(weight_r_ptr_, weight_recu->data_c(), weight_recu->ElementsNum() * sizeof(float));
-
-  int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
-  bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
-  if (bias_ptr_ == nullptr) {
-    MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
-    return RET_ERROR;
-  }
-
-  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
-  const int state_bias_offset = 3 * gru_parm_->hidden_size_;
-  for (int i = 0; i < state_bias_offset; i++) {
-    bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
-  }
-  if (gru_parm_->bidirectional_) {
-    bias_data += 3 * gru_parm_->hidden_size_ * 2;
-    auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
-    for (int i = 0; i < state_bias_offset; i++) {
-      backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
+  auto weight_batch = gru_param_->bidirectional_ ? 6 : 3;
+  if (!is_vec_) {
+    // malloc and init input * weight right matrix buffer
+    auto weight_g = in_tensors_.at(1);
+    MS_ASSERT(weight_g != nullptr);
+    weight_g_ptr_ = reinterpret_cast<float *>(
+      malloc(weight_batch * gru_param_->col_align_ * gru_param_->input_size_ * sizeof(float)));
+    if (weight_g_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error.";
+      return RET_ERROR;
+    }
+    auto weight_i_data = reinterpret_cast<float *>(weight_g->data_c());
+    PackLstmWeight(weight_g_ptr_, weight_i_data, weight_batch, gru_param_->input_size_, gru_param_->hidden_size_,
+                   gru_param_->col_align_);
+
+    // malloc and init state * weight right matrix buffer
+    auto weight_r = in_tensors_.at(2);
+    MS_ASSERT(weight_r != nullptr);
+    weight_r_ptr_ = reinterpret_cast<float *>(
+      malloc(weight_batch * gru_param_->col_align_ * gru_param_->hidden_size_ * sizeof(float)));
+    if (weight_r_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "GruCPUKernel malloc weight_r_ptr_ error.";
+      return RET_ERROR;
+    }
+    auto weight_r_data = reinterpret_cast<float *>(weight_r->data_c());
+    PackLstmWeight(weight_r_ptr_, weight_r_data, weight_batch, gru_param_->hidden_size_, gru_param_->hidden_size_,
+                   gru_param_->col_align_);
+
+    // init bias
+    int bias_batch = gru_param_->bidirectional_ ? 16 : 8;
+    bias_ptr_ = reinterpret_cast<float *>(malloc(bias_batch * gru_param_->col_align_ * sizeof(float)));
+    if (bias_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
+      return RET_ERROR;
+    }
+    memset(bias_ptr_, 0, bias_batch * gru_param_->col_align_ * sizeof(float));
+    auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
+    for (int i = 0; i < bias_batch; i++) {
+      auto src_batch = bias_data + i * gru_param_->hidden_size_;
+      auto dst_batch = bias_ptr_ + i * gru_param_->col_align_;
+      memcpy(dst_batch, src_batch, gru_param_->hidden_size_ * sizeof(float));
     }
   }
   return RET_OK;
@@ -117,24 +145,42 @@ int GruCPUKernel::Init() {
 }

 int GruCPUKernel::ReSize() {
-  FreeTmpBuffer();
   auto ret = InitParam();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "GruCPUKernel InitParam error.";
     return RET_ERROR;
   }

+  FreeTmpBuffer();
   ret = InitWeightBias();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error.";
     FreeTmpBuffer();
     return RET_ERROR;
   }
+  return RET_OK;
+}

-  ret = InitBuffer();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "GruCPUKernel InitBuffer error.";
-    FreeTmpBuffer();
+int GruCPUKernel::MallocRunBuffer() {
+  if (!is_vec_) {
+    matmul_buffer_[0] = reinterpret_cast<float *>(
+      context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->input_size_ * sizeof(float)));
+    if (matmul_buffer_[0] == nullptr) {
+      MS_LOG(ERROR) << "GruCPUKernel malloc input * weight left matrix error.";
+      return RET_ERROR;
+    }
+
+    matmul_buffer_[1] = reinterpret_cast<float *>(
+      context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->hidden_size_ * sizeof(float)));
+    if (matmul_buffer_[1] == nullptr) {
+      MS_LOG(ERROR) << "GruCPUKernel malloc state * weight left matrix error.";
+      return RET_ERROR;
+    }
+  }
+  gate_buffer_ = reinterpret_cast<float *>(
+    context_->allocator->Malloc(6 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float)));
+  if (gate_buffer_ == nullptr) {
+    MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
     return RET_ERROR;
   }
   return RET_OK;
@@ -153,22 +199,35 @@ int GruCPUKernel::Run() {
   MS_ASSERT(output_ptr);
   auto output_hidden_state = out_tensors_[1];
   memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
-  int check_seq_len = gru_parm_->seq_len_;
+  int check_seq_len = gru_param_->seq_len_;
+
   if (in_tensors_.size() == 6) {
     auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
-    if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
+    if (!std::equal(seq_len + 1, seq_len + gru_param_->batch_, seq_len)) {
       MS_LOG(ERROR) << "different batch seq_len is currently not supported";
       return RET_ERROR;
     }
     check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
   }

+  auto ret = MallocRunBuffer();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "GruCPUKernel MallocRunBuffer error.";
+    return RET_ERROR;
+  }
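+  // vector path: batch_ == 1 multiplies against the unpacked tensors via MatMulAcc, so the
+  // weight and bias pointers can borrow the tensor buffers directly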
+  if (is_vec_) {
+    weight_g_ptr_ = reinterpret_cast<float *>(in_tensors_[1]->data_c());
+    weight_r_ptr_ = reinterpret_cast<float *>(in_tensors_[2]->data_c());
+    bias_ptr_ = reinterpret_cast<float *>(in_tensors_[3]->data_c());
+  }
   MS_ASSERT(weight_g_ptr_ != nullptr);
   MS_ASSERT(weight_r_ptr_ != nullptr);
   MS_ASSERT(bias_ptr_ != nullptr);
   MS_ASSERT(gate_buffer_ != nullptr);
   Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
-      reinterpret_cast<float *>(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_);
+      reinterpret_cast<float *>(output_hidden_state->data_c()), gate_buffer_, matmul_buffer_, check_seq_len,
+      gru_param_);
+  FreeRunBuffer();
   return RET_OK;
 }

diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h
index ee661d9d25..53ebe7a8b1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h
@@ -26,7 +26,7 @@ class GruCPUKernel : public LiteKernel {
               const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
               const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
-    gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
+    gru_param_ = reinterpret_cast<GruParameter *>(op_parameter_);
   }

   ~GruCPUKernel() override { FreeTmpBuffer(); }
@@ -37,15 +37,20 @@ class GruCPUKernel : public LiteKernel {

  private:
   void FreeTmpBuffer();
+  void FreeRunBuffer();
   int InitParam();
-  int InitBuffer();
+  int MallocRunBuffer();
   int InitWeightBias();

   float *gate_buffer_ = nullptr;
   float *weight_g_ptr_ = nullptr;
   float *weight_r_ptr_ = nullptr;
   float *bias_ptr_ = nullptr;
-  GruParameter *gru_parm_ = nullptr;
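+  // matmul_buffer_[0] holds packed input rows, [1] packed hidden-state rows; both are
+  // allocated in MallocRunBuffer() and released by FreeRunBuffer() every Run()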
+  float *matmul_buffer_[2];
+  int row_tile_ = 0;
+  int col_tile_ = 0;
+  bool is_vec_ = false;
+  GruParameter *gru_param_ = nullptr;
 };
 }  // namespace mindspore::kernel

diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
index 828dd14f04..b036ba9181 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
@@ -57,24 +57,8 @@ void LstmCPUKernel::FreeRunBuffer() {
   }
 }

-int InitRightMatrix(float *dst, const float *src, int batch, int deep, int col, int col_align, bool is_vec) {
-  for (int i = 0; i < batch; i++) {
-    auto src_batch = src + i * col * deep;
-    auto dst_batch = dst + i * col_align * deep;
-#ifdef ENABLE_AVX
-    RowMajor2Col16Major(src_batch, dst_batch, col, deep);
-#elif defined(ENABLE_ARM32)
-    RowMajor2Col4Major(src_batch, dst_batch, col, deep);
-#else
-    RowMajor2Col8Major(src_batch, dst_batch, col, deep);
-#endif
-  }
-  return RET_OK;
-}
-
 int LstmCPUKernel::InitWeightBias() {
   auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4;
-
   if (!is_vec_) {
     // malloc and init input * weight right matrix buffer
     auto weight_i = in_tensors_.at(1);
@@ -86,8 +70,8 @@ int LstmCPUKernel::InitWeightBias() {
       return RET_ERROR;
     }
     auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
-    InitRightMatrix(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_,
-                    lstm_param_->col_align_, is_vec_);
+    PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_,
+                   lstm_param_->col_align_);

     // malloc and init state * weight right matrix buffer
     auto weight_h = in_tensors_.at(2);
@@ -99,8 +83,8 @@ int LstmCPUKernel::InitWeightBias() {
       return RET_ERROR;
     }
     auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
-    InitRightMatrix(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
-                    lstm_param_->col_align_, is_vec_);
+    PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
+                   lstm_param_->col_align_);

     // init bias
     int bias_batch = lstm_param_->bidirectional_ ? 16 : 8;
@@ -235,7 +219,7 @@ int LstmCPUKernel::Run() {

   auto ret = MallocRunBuffer();
   if (ret != RET_OK) {
-    MS_LOG(ERROR) << "LstmCPUKernel InitRunBuffer error.";
+    MS_LOG(ERROR) << "LstmCPUKernel MallocRunBuffer error.";
     return RET_ERROR;
   }

@@ -244,7 +228,6 @@ int LstmCPUKernel::Run() {
     weight_h_ptr_ = reinterpret_cast<float *>(in_tensors_[2]->data_c());
     bias_ptr_ = reinterpret_cast<float *>(in_tensors_[3]->data_c());
   }
-
   MS_ASSERT(weight_h_ptr_);
   MS_ASSERT(weight_i_ptr_);
   MS_ASSERT(bias_ptr_);