From: @yangruoqi713 Reviewed-by: @zhang_xue_tong Signed-off-by: @zhang_xue_tongtags/v1.2.0-rc1
| @@ -19,20 +19,7 @@ | |||||
| #include <float.h> | #include <float.h> | ||||
| #include "nnacl/fp32/activation_fp32.h" | #include "nnacl/fp32/activation_fp32.h" | ||||
| #include "nnacl/fp32/arithmetic_fp32.h" | #include "nnacl/fp32/arithmetic_fp32.h" | ||||
| #include "nnacl/fp32/mul_fp32.h" | |||||
| void InitGate(float *gate_buffer, const float *bias, const LstmParameter *lstm_parm) { | |||||
| int gate_offest = 0; | |||||
| for (int l = 0; l < 4; l++) { | |||||
| int batch_offest = gate_offest; | |||||
| int bias_offest = l * lstm_parm->hidden_size_; | |||||
| for (int b = 0; b < lstm_parm->batch_; b++) { | |||||
| memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float)); | |||||
| batch_offest += lstm_parm->hidden_size_; | |||||
| } | |||||
| gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_; | |||||
| } | |||||
| } | |||||
| #include "nnacl/fp32/matmul_fp32.h" | |||||
| // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] | // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] | ||||
| void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) { | void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) { | ||||
| @@ -134,106 +121,131 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd | |||||
| } | } | ||||
| } | } | ||||
| void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight, | |||||
| const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight, | |||||
| const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight, | |||||
| const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, | |||||
| const LstmParameter *lstm_parm) { | |||||
| InitGate(gate_buffer, bias, lstm_parm); | |||||
| void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) { | |||||
| if (is_vec) { | |||||
| memcpy(c, bias, col * sizeof(float)); | |||||
| MatMulAcc(c, a, b, row, col, deep); | |||||
| } else { | |||||
| MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); | |||||
| } | |||||
| } | |||||
| void PackLstmInput(float *dst, const float *src, int row, int deep) { | |||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col6Major(src, dst, row, deep); | |||||
| #elif defined(ENABLE_SSE) | |||||
| RowMajor2Col4Major(src, dst, row, deep); | |||||
| #else | |||||
| RowMajor2Col12Major(src, dst, row, deep); | |||||
| #endif | |||||
| } | |||||
| void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep, | |||||
| int col, int col_align, bool is_vec) { | |||||
| const float *input_weight = weight; | |||||
| const float *forget_weight = weight + deep * col * 2; | |||||
| const float *cell_weight = weight + deep * col * 3; | |||||
| const float *output_weight = weight + deep * col; | |||||
| const float *input_bias = bias; | |||||
| const float *forget_bias = bias + col_align * 2; | |||||
| const float *cell_bias = bias + col_align * 3; | |||||
| const float *output_bias = bias + col_align; | |||||
| float *input_gate = gate_buffer; | float *input_gate = gate_buffer; | ||||
| float *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2; | |||||
| float *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3; | |||||
| float *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1; | |||||
| float *forget_gate = gate_buffer + row * col * 2; | |||||
| float *cell_gate = gate_buffer + row * col * 3; | |||||
| float *output_gate = gate_buffer + row * col; | |||||
| LstmMatmul(input_gate, input, input_weight, input_bias, row, deep, col, is_vec); | |||||
| LstmMatmul(forget_gate, input, forget_weight, forget_bias, row, deep, col, is_vec); | |||||
| LstmMatmul(cell_gate, input, cell_weight, cell_bias, row, deep, col, is_vec); | |||||
| LstmMatmul(output_gate, input, output_weight, output_bias, row, deep, col, is_vec); | |||||
| } | |||||
| void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight, | |||||
| const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, | |||||
| float *matmul_buffer[2], const LstmParameter *lstm_param) { | |||||
| bool is_vec = lstm_param->batch_ == 1; | |||||
| // input * weight | // input * weight | ||||
| MatMulAcc(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); | |||||
| MatMulAcc(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, | |||||
| lstm_parm->input_size_); | |||||
| MatMulAcc(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); | |||||
| MatMulAcc(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, | |||||
| lstm_parm->input_size_); | |||||
| if (is_vec) { | |||||
| UpdateGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_, | |||||
| lstm_param->hidden_size_, lstm_param->col_align_, is_vec); | |||||
| } else { | |||||
| // pack input for matmul | |||||
| PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_); | |||||
| UpdateGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_, | |||||
| lstm_param->hidden_size_, lstm_param->col_align_, is_vec); | |||||
| } | |||||
| // state * weight | // state * weight | ||||
| MatMulAcc(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, | |||||
| lstm_parm->hidden_size_); | |||||
| MatMulAcc(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, | |||||
| lstm_parm->hidden_size_); | |||||
| MatMulAcc(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, | |||||
| lstm_parm->hidden_size_); | |||||
| MatMulAcc(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, | |||||
| lstm_parm->hidden_size_); | |||||
| float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4; | |||||
| const float *state_bias = bias + lstm_param->col_align_ * 4; | |||||
| if (is_vec) { | |||||
| UpdateGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, | |||||
| lstm_param->hidden_size_, lstm_param->col_align_, is_vec); | |||||
| } else { | |||||
| // pack state for matmul | |||||
| PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_); | |||||
| UpdateGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, | |||||
| lstm_param->hidden_size_, lstm_param->col_align_, is_vec); | |||||
| } | |||||
| ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_); | |||||
| float *input_gate = gate_buffer; | |||||
| float *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2; | |||||
| float *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3; | |||||
| float *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_; | |||||
| // update input_gate | // update input_gate | ||||
| Sigmoid(input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, input_gate); | |||||
| Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate); | |||||
| // update forget_gate | // update forget_gate | ||||
| Sigmoid(forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, forget_gate); | |||||
| Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate); | |||||
| // update cell_gate | // update cell_gate | ||||
| Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate); | |||||
| Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate); | |||||
| // update cell state | // update cell state | ||||
| UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, | |||||
| lstm_parm->smooth_); | |||||
| UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_, | |||||
| lstm_param->hidden_size_, lstm_param->smooth_); | |||||
| // update output_gate | // update output_gate | ||||
| Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate); | |||||
| Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate); | |||||
| // update output | // update output | ||||
| UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, | |||||
| lstm_parm->smooth_); | |||||
| memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); | |||||
| if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) { | |||||
| memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); | |||||
| memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_, | |||||
| lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); | |||||
| UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_, | |||||
| lstm_param->smooth_); | |||||
| memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); | |||||
| if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) { | |||||
| memcpy(cell_state, state_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); | |||||
| memcpy(hidden_state, state_buffer + lstm_param->batch_ * lstm_param->hidden_size_, | |||||
| lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); | |||||
| } | } | ||||
| } | } | ||||
| void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, | void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, | ||||
| float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, | |||||
| const LstmParameter *lstm_parm) { | |||||
| float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2], | |||||
| const LstmParameter *lstm_param) { | |||||
| // forward | // forward | ||||
| const float *input_input_weight = weight_i; | |||||
| const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; | |||||
| const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3; | |||||
| const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1; | |||||
| const float *state_input_weight = weight_h; | |||||
| const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2; | |||||
| const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3; | |||||
| const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1; | |||||
| for (int t = 0; t < lstm_parm->seq_len_; t++) { | |||||
| const float *input_ptr = input + t * lstm_parm->input_step_; | |||||
| float *output_ptr = output + t * lstm_parm->output_step_; | |||||
| LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight, | |||||
| state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state, | |||||
| cell_state, gate_buffer, state_buffer, lstm_parm); | |||||
| for (int t = 0; t < lstm_param->seq_len_; t++) { | |||||
| const float *input_ptr = input + t * lstm_param->input_step_; | |||||
| float *output_ptr = output + t * lstm_param->output_step_; | |||||
| LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer, | |||||
| matmul_buffer, lstm_param); | |||||
| } | } | ||||
| // backward | // backward | ||||
| if (lstm_parm->bidirectional_) { | |||||
| input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4; | |||||
| input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6; | |||||
| input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7; | |||||
| input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5; | |||||
| state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4; | |||||
| state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6; | |||||
| state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7; | |||||
| state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5; | |||||
| float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_; | |||||
| const float *backward_bias = bias + 4 * lstm_parm->hidden_size_; | |||||
| float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_; | |||||
| float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_; | |||||
| for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) { | |||||
| const float *input_ptr = input + t * lstm_parm->input_step_; | |||||
| float *output_ptr = backward_output + t * lstm_parm->output_step_; | |||||
| LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, | |||||
| input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, | |||||
| backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm); | |||||
| if (lstm_param->bidirectional_) { | |||||
| const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_; | |||||
| const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_; | |||||
| const float *backward_bias = bias + 8 * lstm_param->hidden_size_; | |||||
| float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; | |||||
| float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; | |||||
| float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; | |||||
| for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) { | |||||
| const float *input_ptr = input + t * lstm_param->input_step_; | |||||
| float *output_ptr = backward_output + t * lstm_param->output_step_; | |||||
| LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state, | |||||
| backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int | |||||
| int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); | int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); | ||||
| void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, | void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, | ||||
| float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, | |||||
| float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2], | |||||
| const LstmParameter *lstm_parm); | const LstmParameter *lstm_parm); | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -34,6 +34,8 @@ typedef struct LstmParameter { | |||||
| // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth) | // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth) | ||||
| // output_cell = old_cell * smooth + new_cell * (1 - smooth) | // output_cell = old_cell * smooth + new_cell * (1 - smooth) | ||||
| float smooth_; | float smooth_; | ||||
| int col_align_; | |||||
| int row_align_; | |||||
| } LstmParameter; | } LstmParameter; | ||||
| #endif // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_ | #endif // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_ | ||||
| @@ -84,9 +84,8 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t | |||||
| bool IsPackedOp(schema::PrimitiveType op_type) { | bool IsPackedOp(schema::PrimitiveType op_type) { | ||||
| static std::vector<schema::PrimitiveType> packed_ops = { | static std::vector<schema::PrimitiveType> packed_ops = { | ||||
| schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, | |||||
| schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, | |||||
| schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm}; | |||||
| schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D, | |||||
| schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_MatMul}; | |||||
| return IsContain(packed_ops, op_type); | return IsContain(packed_ops, op_type); | ||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -20,35 +20,104 @@ | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "nnacl/fp32/matmul_fp32.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_MEMORY_FAILED; | |||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::PrimitiveType_Lstm; | using mindspore::schema::PrimitiveType_Lstm; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| void LstmCPUKernel::FreeTmpBuffer() { | void LstmCPUKernel::FreeTmpBuffer() { | ||||
| if (gate_buffer_ != nullptr) { | |||||
| free(gate_buffer_); | |||||
| gate_buffer_ = nullptr; | |||||
| } | |||||
| if (state_buffer_ != nullptr) { | |||||
| free(state_buffer_); | |||||
| state_buffer_ = nullptr; | |||||
| if (!is_vec_) { | |||||
| if (weight_i_ptr_ != nullptr) { | |||||
| free(weight_i_ptr_); | |||||
| weight_i_ptr_ = nullptr; | |||||
| } | |||||
| if (weight_h_ptr_ != nullptr) { | |||||
| free(weight_h_ptr_); | |||||
| weight_h_ptr_ = nullptr; | |||||
| } | |||||
| if (bias_ptr_ != nullptr) { | |||||
| free(bias_ptr_); | |||||
| bias_ptr_ = nullptr; | |||||
| } | |||||
| } | } | ||||
| if (weight_i_ptr_ != nullptr) { | |||||
| free(weight_i_ptr_); | |||||
| weight_i_ptr_ = nullptr; | |||||
| } | |||||
| void LstmCPUKernel::FreeRunBuffer() { | |||||
| context_->allocator->Free(gate_buffer_); | |||||
| context_->allocator->Free(state_buffer_); | |||||
| if (!is_vec_) { | |||||
| for (int i = 0; i < 2; i++) { | |||||
| context_->allocator->Free(matmul_buffer_[i]); | |||||
| } | |||||
| } | } | ||||
| if (weight_h_ptr_ != nullptr) { | |||||
| free(weight_h_ptr_); | |||||
| weight_h_ptr_ = nullptr; | |||||
| } | |||||
| int InitRightMatrix(float *dst, const float *src, int batch, int deep, int col, int col_align, bool is_vec) { | |||||
| for (int i = 0; i < batch; i++) { | |||||
| auto src_batch = src + i * col * deep; | |||||
| auto dst_batch = dst + i * col_align * deep; | |||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col16Major(src_batch, dst_batch, col, deep); | |||||
| #elif defined(ENABLE_ARM32) | |||||
| RowMajor2Col4Major(src_batch, dst_batch, col, deep); | |||||
| #else | |||||
| RowMajor2Col8Major(src_batch, dst_batch, col, deep); | |||||
| #endif | |||||
| } | } | ||||
| if (bias_ptr_ != nullptr) { | |||||
| free(bias_ptr_); | |||||
| bias_ptr_ = nullptr; | |||||
| return RET_OK; | |||||
| } | |||||
| int LstmCPUKernel::InitWeightBias() { | |||||
| auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4; | |||||
| if (!is_vec_) { | |||||
| // malloc and init input * weight right matrix buffer | |||||
| auto weight_i = in_tensors_.at(1); | |||||
| MS_ASSERT(weight_i != nullptr); | |||||
| weight_i_ptr_ = reinterpret_cast<float *>( | |||||
| malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float))); | |||||
| if (weight_i_ptr_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c()); | |||||
| InitRightMatrix(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_, | |||||
| lstm_param_->col_align_, is_vec_); | |||||
| // malloc and init state * weight right matrix buffer | |||||
| auto weight_h = in_tensors_.at(2); | |||||
| MS_ASSERT(weight_h != nullptr); | |||||
| weight_h_ptr_ = reinterpret_cast<float *>( | |||||
| malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float))); | |||||
| if (weight_h_ptr_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c()); | |||||
| InitRightMatrix(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, | |||||
| lstm_param_->col_align_, is_vec_); | |||||
| // init bias | |||||
| int bias_batch = lstm_param_->bidirectional_ ? 16 : 8; | |||||
| bias_ptr_ = reinterpret_cast<float *>(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float))); | |||||
| if (bias_ptr_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float)); | |||||
| auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c()); | |||||
| for (int i = 0; i < bias_batch; i++) { | |||||
| auto src_batch = bias_data + i * lstm_param_->hidden_size_; | |||||
| auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_; | |||||
| memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float)); | |||||
| } | |||||
| } | } | ||||
| return RET_OK; | |||||
| } | } | ||||
| int LstmCPUKernel::InitParam() { | int LstmCPUKernel::InitParam() { | ||||
| @@ -67,80 +136,27 @@ int LstmCPUKernel::InitParam() { | |||||
| lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_; | lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_; | ||||
| lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ | lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ | ||||
| : lstm_param_->batch_ * lstm_param_->hidden_size_; | : lstm_param_->batch_ * lstm_param_->hidden_size_; | ||||
| return RET_OK; | |||||
| } | |||||
| int LstmCPUKernel::InitBuffer() { | |||||
| gate_buffer_ = reinterpret_cast<float *>(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float))); | |||||
| if (gate_buffer_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) { | |||||
| int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); | |||||
| state_buffer_ = reinterpret_cast<float *>(malloc(buffer_size)); | |||||
| if (state_buffer_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int LstmCPUKernel::InitWeightBias() { | |||||
| // copy weight_i and weight_h | |||||
| auto weight_i = in_tensors_.at(1); | |||||
| MS_ASSERT(weight_i != nullptr); | |||||
| weight_i_ptr_ = reinterpret_cast<float *>(malloc(weight_i->ElementsNum() * sizeof(float))); | |||||
| if (weight_i_ptr_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memcpy(weight_i_ptr_, weight_i->data_c(), weight_i->ElementsNum() * sizeof(float)); | |||||
| auto weight_h = in_tensors_.at(2); | |||||
| MS_ASSERT(weight_h != nullptr); | |||||
| weight_h_ptr_ = reinterpret_cast<float *>(malloc(weight_h->ElementsNum() * sizeof(float))); | |||||
| if (weight_h_ptr_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memcpy(weight_h_ptr_, weight_h->data_c(), weight_h->ElementsNum() * sizeof(float)); | |||||
| std::vector<int> w_shape = weight_i->shape(); | |||||
| auto hidden_size = w_shape.at(1) / 4; | |||||
| // init bias | |||||
| int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size; | |||||
| bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float))); | |||||
| if (bias_ptr_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c()); | |||||
| const int state_bias_offset = 4 * hidden_size; | |||||
| for (int i = 0; i < state_bias_offset; i++) { | |||||
| bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset]; | |||||
| } | |||||
| if (lstm_param_->bidirectional_) { | |||||
| bias_data += 4 * hidden_size * 2; | |||||
| auto backward_bias = bias_ptr_ + 4 * hidden_size; | |||||
| for (int i = 0; i < state_bias_offset; i++) { | |||||
| backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset]; | |||||
| } | |||||
| } | |||||
| #ifdef ENABLE_AVX | |||||
| row_tile_ = C6NUM; | |||||
| col_tile_ = C16NUM; | |||||
| #elif defined(ENABLE_ARM32) | |||||
| row_tile_ = C12NUM; | |||||
| col_tile_ = C4NUM; | |||||
| #elif defined(ENABLE_SSE) | |||||
| row_tile_ = C4NUM; | |||||
| col_tile_ = C8NUM; | |||||
| #else | |||||
| row_tile_ = C12NUM; | |||||
| col_tile_ = C8NUM; | |||||
| #endif | |||||
| is_vec_ = lstm_param_->batch_ == 1; | |||||
| lstm_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_); | |||||
| lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int LstmCPUKernel::Init() { | int LstmCPUKernel::Init() { | ||||
| FreeTmpBuffer(); | |||||
| auto ret = InitWeightBias(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error."; | |||||
| FreeTmpBuffer(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!InferShapeDone()) { | if (!InferShapeDone()) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -154,15 +170,50 @@ int LstmCPUKernel::ReSize() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ret = InitBuffer(); | |||||
| FreeTmpBuffer(); | |||||
| ret = InitWeightBias(); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error."; | |||||
| MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error."; | |||||
| FreeTmpBuffer(); | FreeTmpBuffer(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int LstmCPUKernel::MallocRunBuffer() { | |||||
| if (!is_vec_) { | |||||
| matmul_buffer_[0] = reinterpret_cast<float *>( | |||||
| context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float))); | |||||
| if (matmul_buffer_[0] == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| matmul_buffer_[1] = reinterpret_cast<float *>( | |||||
| context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float))); | |||||
| if (matmul_buffer_[1] == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc state * weight left matirx error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| gate_buffer_ = reinterpret_cast<float *>( | |||||
| context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float))); | |||||
| if (gate_buffer_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) { | |||||
| int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); | |||||
| state_buffer_ = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size)); | |||||
| if (state_buffer_ == nullptr) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int LstmCPUKernel::Run() { | int LstmCPUKernel::Run() { | ||||
| auto input = in_tensors_.at(kInputIndex); | auto input = in_tensors_.at(kInputIndex); | ||||
| MS_ASSERT(input != nullptr); | MS_ASSERT(input != nullptr); | ||||
| @@ -182,13 +233,26 @@ int LstmCPUKernel::Run() { | |||||
| auto output_cell_state = out_tensors_[2]; | auto output_cell_state = out_tensors_[2]; | ||||
| memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float)); | memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float)); | ||||
| auto ret = MallocRunBuffer(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "LstmCPUKernel InitRunBuffer error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (is_vec_) { | |||||
| weight_i_ptr_ = reinterpret_cast<float *>(in_tensors_[1]->data_c()); | |||||
| weight_h_ptr_ = reinterpret_cast<float *>(in_tensors_[2]->data_c()); | |||||
| bias_ptr_ = reinterpret_cast<float *>(in_tensors_[3]->data_c()); | |||||
| } | |||||
| MS_ASSERT(weight_h_ptr_); | MS_ASSERT(weight_h_ptr_); | ||||
| MS_ASSERT(weight_i_ptr_); | MS_ASSERT(weight_i_ptr_); | ||||
| MS_ASSERT(bias_ptr_); | MS_ASSERT(bias_ptr_); | ||||
| MS_ASSERT(gate_buffer_); | MS_ASSERT(gate_buffer_); | ||||
| Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, | Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, | ||||
| reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()), | reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()), | ||||
| gate_buffer_, state_buffer_, lstm_param_); | |||||
| gate_buffer_, state_buffer_, matmul_buffer_, lstm_param_); | |||||
| FreeRunBuffer(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -39,8 +39,9 @@ class LstmCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| void FreeTmpBuffer(); | void FreeTmpBuffer(); | ||||
| void FreeRunBuffer(); | |||||
| int InitParam(); | int InitParam(); | ||||
| int InitBuffer(); | |||||
| int MallocRunBuffer(); | |||||
| int InitWeightBias(); | int InitWeightBias(); | ||||
| float *gate_buffer_ = nullptr; | float *gate_buffer_ = nullptr; | ||||
| @@ -48,6 +49,10 @@ class LstmCPUKernel : public LiteKernel { | |||||
| float *weight_i_ptr_ = nullptr; | float *weight_i_ptr_ = nullptr; | ||||
| float *weight_h_ptr_ = nullptr; | float *weight_h_ptr_ = nullptr; | ||||
| float *bias_ptr_ = nullptr; | float *bias_ptr_ = nullptr; | ||||
| float *matmul_buffer_[2]; | |||||
| int row_tile_ = 0; | |||||
| int col_tile_ = 0; | |||||
| bool is_vec_ = false; | |||||
| LstmParameter *lstm_param_ = nullptr; | LstmParameter *lstm_param_ = nullptr; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||