diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.c b/mindspore/lite/nnacl/fp32/lstm_fp32.c index 35627feb91..969f88fb0d 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.c +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.c @@ -35,6 +35,24 @@ void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, } } +void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float *src_batch = src + i * col; + float *dst_batch = dst + i * col_align; + memcpy(dst_batch, src_batch, col * sizeof(float)); + } + if (is_bidirectional) { + const float *backward_src = src + batch * col; + float *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float *backward_dst_batch = backward_dst + i * col_align; + memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float)); + } + } +} + void PackLstmInput(const float *src, float *dst, int row, int deep) { #ifdef ENABLE_AVX RowMajor2Col6Major(src, dst, row, deep); @@ -162,39 +180,28 @@ void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, } } -void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight, - const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer[2], - float *matmul_buffer[2], const LstmParameter *lstm_param) { +void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, + const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state, + float *state_gate, float *state_buffer[2], float *packed_state, const LstmParameter *lstm_param) { bool is_vec = lstm_param->batch_ == 1; - // input * weight - if (is_vec) { - UpdateLstmGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_, - lstm_param->hidden_size_, lstm_param->col_align_, is_vec); - } else { - // pack input for matmul - PackLstmInput(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_); - UpdateLstmGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_, - lstm_param->hidden_size_, lstm_param->col_align_, is_vec); - } - // state * weight - float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4; - const float *state_bias = bias + lstm_param->col_align_ * 4; if (is_vec) { UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, - lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); } else { // pack state for matmul - PackLstmInput(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_); - UpdateLstmGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, - lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->hidden_size_); + UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); } - ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate, + lstm_param->batch_ * lstm_param->hidden_size_); - float *input_gate = gate_buffer; - float *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2; - float *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3; - float *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_; // update input_gate Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate); @@ -223,30 +230,58 @@ void LstmStepUnit(float *output, const float *input, const float *input_weight, } } -void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, - float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer[2], float *matmul_buffer[2], - const LstmParameter *lstm_param) { - // forward +void LstmUnidirectional(float *output, const float *packed_input, const float *weight_i, const float *weight_h, + const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, + float *state_buffer[2], float *buffer[4], const LstmParameter *lstm_param, bool is_backward) { + float *gate = buffer[1]; + float *packed_state = buffer[2]; + float *state_gate = buffer[3]; + for (int i = 0; i < 4; i++) { + const float *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; + const float *bias_loop = input_bias + lstm_param->input_col_align_ * i; + float *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; + MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, + lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, + OutType_Nhwc); + } + + float *input_gate = gate; + float *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_; for (int t = 0; t < lstm_param->seq_len_; t++) { - const float *input_ptr = input + t * lstm_param->input_step_; - float *output_ptr = output + t * lstm_param->output_step_; - LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer, - matmul_buffer, lstm_param); + int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t; + float *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *output_ptr = output + real_t * lstm_param->output_step_; + LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, + hidden_state, cell_state, state_gate, state_buffer, packed_state, lstm_param); } +} + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, + const float *state_bias, float *hidden_state, float *cell_state, float *state_buffer[2], float *buffer[4], + const LstmParameter *lstm_param) { + // forward + float *packed_input = buffer[0]; + PackLstmInput(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_); + LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state, + state_buffer, buffer, lstm_param, false); // backward if (lstm_param->bidirectional_) { - const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_; - const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_; - const float *backward_bias = bias + 8 * lstm_param->col_align_; + const float *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; + const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_; + const float *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; + const float *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; - for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) { - const float *input_ptr = input + t * lstm_param->input_step_; - float *output_ptr = backward_output + t * lstm_param->output_step_; - LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state, - backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param); - } + + LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, + backward_state_bias, backward_hidden_state, backward_cell_state, state_buffer, buffer, + lstm_param, true); } } diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.h b/mindspore/lite/nnacl/fp32/lstm_fp32.h index ef5ff9f954..3a7142e6c6 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.h +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.h @@ -23,6 +23,8 @@ extern "C" { #endif void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align); +void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional); + void PackLstmInput(const float *src, float *dst, int row, int deep); void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec); @@ -31,8 +33,8 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); -void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, - float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer[2], float *matmul_buffer[2], +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, + const float *state_bias, float *hidden_state, float *cell_state, float *state_buffer[2], float *buffer[4], const LstmParameter *lstm_param); #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/lstm_parameter.h b/mindspore/lite/nnacl/lstm_parameter.h index 98d78727db..d29c880f87 100644 --- a/mindspore/lite/nnacl/lstm_parameter.h +++ b/mindspore/lite/nnacl/lstm_parameter.h @@ -32,6 +32,10 @@ typedef struct LstmParameter { bool bidirectional_; float zoneout_cell_; float zoneout_hidden_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; int col_align_; int row_align_; } LstmParameter; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc index 71dbf39068..f35f6f4f3a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc @@ -31,78 +31,98 @@ using mindspore::schema::PrimitiveType_LSTM; namespace mindspore::kernel { void LstmCPUKernel::FreeTmpBuffer() { + if (weight_i_ptr_ != nullptr) { + free(weight_i_ptr_); + weight_i_ptr_ = nullptr; + } + if (input_bias_ != nullptr) { + free(input_bias_); + input_bias_ = nullptr; + } if (!is_vec_) { - if (weight_i_ptr_ != nullptr) { - free(weight_i_ptr_); - weight_i_ptr_ = nullptr; - } if (weight_h_ptr_ != nullptr) { free(weight_h_ptr_); weight_h_ptr_ = nullptr; } - if (bias_ptr_ != nullptr) { - free(bias_ptr_); - bias_ptr_ = nullptr; - } + } + if (state_bias_ != nullptr) { + free(state_bias_); + state_bias_ = nullptr; } } void LstmCPUKernel::FreeRunBuffer() { - context_->allocator->Free(gate_buffer_); for (int i = 0; i < 2; i++) { context_->allocator->Free(state_buffer_[i]); } + context_->allocator->Free(buffer_[0]); + context_->allocator->Free(buffer_[1]); if (!is_vec_) { - for (int i = 0; i < 2; i++) { - context_->allocator->Free(matmul_buffer_[i]); - } + context_->allocator->Free(buffer_[2]); } + context_->allocator->Free(buffer_[3]); } -int LstmCPUKernel::InitWeightBias() { - auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4; - if (!is_vec_) { - // malloc and init input * weight right matrix buffer - auto weight_i = in_tensors_.at(1); - MS_ASSERT(weight_i != nullptr); - weight_i_ptr_ = reinterpret_cast( - malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float))); - if (weight_i_ptr_ == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; - return RET_ERROR; - } - auto weight_i_data = reinterpret_cast(weight_i->data_c()); - PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_, - lstm_param_->col_align_); +int LstmCPUKernel::InitInputWeightBias() { + // malloc and init input * weight right matrix buffer + // input -- row: seq_len * batch; col: input_size + // weight -- row: hidden_size; col: input_size, need transpose + // result -- row: seq_len * batch; col: hidden_size + auto weight_i = in_tensors_.at(1); + MS_ASSERT(weight_i != nullptr); + weight_i_ptr_ = reinterpret_cast( + malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float))); + if (weight_i_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; + return RET_ERROR; + } + auto weight_i_data = reinterpret_cast(weight_i->data_c()); + PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, lstm_param_->hidden_size_, + lstm_param_->input_col_align_); + + // input bias + input_bias_ = reinterpret_cast(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float))); + if (input_bias_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc input_bias_ error."; + return RET_ERROR; + } + memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float)); + PackLstmBias(input_bias_, reinterpret_cast(in_tensors_.at(3)->data_c()), weight_batch_, + lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_); + return RET_OK; +} - // malloc and init state * weight right matrix buffer - auto weight_h = in_tensors_.at(2); - MS_ASSERT(weight_h != nullptr); +int LstmCPUKernel::InitStateWeightBias() { + // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. + // state -- row: batch; col: hidden_size + // weight -- row: hidden_size; col: hidden_size, need transpose + // result -- row: batch; col: hidden_size + auto weight_h = in_tensors_.at(2); + MS_ASSERT(weight_h != nullptr); + auto weight_h_data = reinterpret_cast(weight_h->data_c()); + if (!is_vec_) { weight_h_ptr_ = reinterpret_cast( - malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float))); + malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float))); if (weight_h_ptr_ == nullptr) { MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error."; return RET_ERROR; } - auto weight_h_data = reinterpret_cast(weight_h->data_c()); - PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, - lstm_param_->col_align_); - - // init bias - int bias_batch = lstm_param_->bidirectional_ ? 16 : 8; - bias_ptr_ = reinterpret_cast(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float))); - if (bias_ptr_ == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error."; - return RET_ERROR; - } - memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float)); - auto bias_data = reinterpret_cast(in_tensors_.at(3)->data_c()); - for (int i = 0; i < bias_batch; i++) { - auto src_batch = bias_data + i * lstm_param_->hidden_size_; - auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_; - memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float)); - } + PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, + lstm_param_->state_col_align_); + } else { + weight_h_ptr_ = weight_h_data; + } + + // state bias + state_bias_ = reinterpret_cast(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float))); + if (state_bias_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc state_bias_ error."; + return RET_ERROR; } + memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float)); + auto state_bias = reinterpret_cast(in_tensors_.at(3)->data_c()) + 4 * lstm_param_->hidden_size_; + PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, + lstm_param_->bidirectional_); return RET_OK; } @@ -119,9 +139,9 @@ int LstmCPUKernel::InitParam() { std::vector w_shape = weight_i->shape(); lstm_param_->hidden_size_ = w_shape.at(1) / 4; - lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_; lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ : lstm_param_->batch_ * lstm_param_->hidden_size_; + weight_batch_ = lstm_param_->bidirectional_ ? 8 : 4; #ifdef ENABLE_AVX row_tile_ = C6NUM; @@ -136,9 +156,12 @@ int LstmCPUKernel::InitParam() { row_tile_ = C12NUM; col_tile_ = C8NUM; #endif + lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_); + lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_); + is_vec_ = lstm_param_->batch_ == 1; - lstm_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_); - lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_); + lstm_param_->state_row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_); + lstm_param_->state_col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_); return RET_OK; } @@ -157,9 +180,16 @@ int LstmCPUKernel::ReSize() { } FreeTmpBuffer(); - ret = InitWeightBias(); + ret = InitInputWeightBias(); if (ret != RET_OK) { - MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error."; + MS_LOG(ERROR) << "LstmCPUKernel InitInputWeightBias error."; + FreeTmpBuffer(); + return RET_ERROR; + } + + ret = InitStateWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitStateWeightBias error."; FreeTmpBuffer(); return RET_ERROR; } @@ -167,32 +197,42 @@ int LstmCPUKernel::ReSize() { } int LstmCPUKernel::MallocRunBuffer() { - if (!is_vec_) { - matmul_buffer_[0] = reinterpret_cast( - context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float))); - if (matmul_buffer_[0] == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error."; - return RET_ERROR; - } + for (int i = 0; i < 4; i++) { + buffer_[i] = nullptr; + } + buffer_[0] = reinterpret_cast( + context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float))); + if (buffer_[0] == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error."; + return RET_ERROR; + } - matmul_buffer_[1] = reinterpret_cast( - context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float))); - if (matmul_buffer_[1] == nullptr) { + buffer_[1] = reinterpret_cast(context_->allocator->Malloc(4 * lstm_param_->seq_len_ * lstm_param_->batch_ * + lstm_param_->hidden_size_ * sizeof(float))); + if (buffer_[1] == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight result matirx error."; + return RET_ERROR; + } + + if (!is_vec_) { + buffer_[2] = reinterpret_cast( + context_->allocator->Malloc(4 * lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float))); + if (buffer_[2] == nullptr) { MS_LOG(ERROR) << "LstmCPUKernel malloc state * weight left matirx error."; return RET_ERROR; } } - gate_buffer_ = reinterpret_cast( - context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float))); - if (gate_buffer_ == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; + buffer_[3] = reinterpret_cast( + context_->allocator->Malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float))); + if (buffer_[3] == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc state gate buffer error."; return RET_ERROR; } state_buffer_[0] = nullptr; state_buffer_[1] = nullptr; if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { - int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); + auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); state_buffer_[0] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); if (state_buffer_[0] == nullptr) { MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for cell error."; @@ -200,7 +240,7 @@ int LstmCPUKernel::MallocRunBuffer() { } } if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { - int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); + auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); state_buffer_[1] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); if (state_buffer_[1] == nullptr) { MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for hidden error."; @@ -235,18 +275,13 @@ int LstmCPUKernel::Run() { return RET_ERROR; } - if (is_vec_) { - weight_i_ptr_ = reinterpret_cast(in_tensors_[1]->data_c()); - weight_h_ptr_ = reinterpret_cast(in_tensors_[2]->data_c()); - bias_ptr_ = reinterpret_cast(in_tensors_[3]->data_c()); - } MS_ASSERT(weight_h_ptr_); MS_ASSERT(weight_i_ptr_); - MS_ASSERT(bias_ptr_); - MS_ASSERT(gate_buffer_); - Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, + MS_ASSERT(input_bias_); + MS_ASSERT(state_bias_); + Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, reinterpret_cast(output_hidden_state->data_c()), reinterpret_cast(output_cell_state->data_c()), - gate_buffer_, state_buffer_, matmul_buffer_, lstm_param_); + state_buffer_, buffer_, lstm_param_); FreeRunBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h index 90256141b7..53179c598b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h @@ -41,16 +41,18 @@ class LstmCPUKernel : public LiteKernel { void FreeRunBuffer(); int InitParam(); int MallocRunBuffer(); - int InitWeightBias(); + int InitInputWeightBias(); + int InitStateWeightBias(); - float *gate_buffer_ = nullptr; float *state_buffer_[2]; float *weight_i_ptr_ = nullptr; float *weight_h_ptr_ = nullptr; - float *bias_ptr_ = nullptr; - float *matmul_buffer_[2]; + float *input_bias_ = nullptr; + float *state_bias_ = nullptr; + float *buffer_[4]; int row_tile_ = 0; int col_tile_ = 0; + int weight_batch_ = 0; bool is_vec_ = false; LstmParameter *lstm_param_ = nullptr; };