Browse Source

!13446 [MSLITE][Develop] optimize cpu fp32 op: lstm

From: @yangruoqi713
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
ddffb61c62
5 changed files with 205 additions and 127 deletions
  1. +77
    -42
      mindspore/lite/nnacl/fp32/lstm_fp32.c
  2. +4
    -2
      mindspore/lite/nnacl/fp32/lstm_fp32.h
  3. +4
    -0
      mindspore/lite/nnacl/lstm_parameter.h
  4. +114
    -79
      mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
  5. +6
    -4
      mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h

+ 77
- 42
mindspore/lite/nnacl/fp32/lstm_fp32.c View File

@@ -35,6 +35,24 @@ void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col,
}
}

// Repack LSTM bias vectors from a tightly packed layout (rows of `col`
// floats) into a column-aligned layout (rows of `col_align` floats).
// Only the first `col` floats of each aligned destination row are written;
// any padding is left untouched by this routine.
// For a bidirectional LSTM, the forward half packs `batch / 2` rows from the
// start of `src`, and the backward half is read starting at
// `src + batch * col` (the caller's tensor layout places the backward biases
// a full `batch` rows in) and written directly after the forward rows.
void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional) {
  const int dir_batch = is_bidirectional ? batch / 2 : batch;
  for (int b = 0; b < dir_batch; ++b) {
    memcpy(dst + b * col_align, src + b * col, col * sizeof(float));
  }
  if (is_bidirectional) {
    const float *bwd_src = src + batch * col;
    float *bwd_dst = dst + dir_batch * col_align;
    for (int b = 0; b < dir_batch; ++b) {
      memcpy(bwd_dst + b * col_align, bwd_src + b * col, col * sizeof(float));
    }
  }
}

void PackLstmInput(const float *src, float *dst, int row, int deep) {
#ifdef ENABLE_AVX
RowMajor2Col6Major(src, dst, row, deep);
@@ -162,39 +180,28 @@ void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight,
}
}

void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer[2],
float *matmul_buffer[2], const LstmParameter *lstm_param) {
void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
float *state_gate, float *state_buffer[2], float *packed_state, const LstmParameter *lstm_param) {
bool is_vec = lstm_param->batch_ == 1;
// input * weight
if (is_vec) {
UpdateLstmGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
} else {
// pack input for matmul
PackLstmInput(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_);
UpdateLstmGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
}

// state * weight
float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4;
const float *state_bias = bias + lstm_param->col_align_ * 4;
if (is_vec) {
UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
} else {
// pack state for matmul
PackLstmInput(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_);
UpdateLstmGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->hidden_size_);
UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
}
ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_);
ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_);
ElementAdd(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate,
lstm_param->batch_ * lstm_param->hidden_size_);
ElementAdd(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate,
lstm_param->batch_ * lstm_param->hidden_size_);
ElementAdd(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate,
lstm_param->batch_ * lstm_param->hidden_size_);

float *input_gate = gate_buffer;
float *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2;
float *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3;
float *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_;
// update input_gate
Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate);

@@ -223,30 +230,58 @@ void LstmStepUnit(float *output, const float *input, const float *input_weight,
}
}

void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer[2], float *matmul_buffer[2],
const LstmParameter *lstm_param) {
// forward
void LstmUnidirectional(float *output, const float *packed_input, const float *weight_i, const float *weight_h,
const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state,
float *state_buffer[2], float *buffer[4], const LstmParameter *lstm_param, bool is_backward) {
float *gate = buffer[1];
float *packed_state = buffer[2];
float *state_gate = buffer[3];
for (int i = 0; i < 4; i++) {
const float *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i;
const float *bias_loop = input_bias + lstm_param->input_col_align_ * i;
float *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i;
MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_,
lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_,
OutType_Nhwc);
}

float *input_gate = gate;
float *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2;
float *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3;
float *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_;
for (int t = 0; t < lstm_param->seq_len_; t++) {
const float *input_ptr = input + t * lstm_param->input_step_;
float *output_ptr = output + t * lstm_param->output_step_;
LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer,
matmul_buffer, lstm_param);
int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t;
float *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
float *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
float *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
float *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
float *output_ptr = output + real_t * lstm_param->output_step_;
LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
hidden_state, cell_state, state_gate, state_buffer, packed_state, lstm_param);
}
}

void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
const float *state_bias, float *hidden_state, float *cell_state, float *state_buffer[2], float *buffer[4],
const LstmParameter *lstm_param) {
// forward
float *packed_input = buffer[0];
PackLstmInput(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_);
LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state,
state_buffer, buffer, lstm_param, false);

// backward
if (lstm_param->bidirectional_) {
const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_;
const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_;
const float *backward_bias = bias + 8 * lstm_param->col_align_;
const float *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_;
const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_;
const float *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_;
const float *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_;
float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_;
float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_;
float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) {
const float *input_ptr = input + t * lstm_param->input_step_;
float *output_ptr = backward_output + t * lstm_param->output_step_;
LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state,
backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param);
}

LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias,
backward_state_bias, backward_hidden_state, backward_cell_state, state_buffer, buffer,
lstm_param, true);
}
}

+ 4
- 2
mindspore/lite/nnacl/fp32/lstm_fp32.h View File

@@ -23,6 +23,8 @@ extern "C" {
#endif
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align);

void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional);

void PackLstmInput(const float *src, float *dst, int row, int deep);

void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec);
@@ -31,8 +33,8 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int

int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);

void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer[2], float *matmul_buffer[2],
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
const float *state_bias, float *hidden_state, float *cell_state, float *state_buffer[2], float *buffer[4],
const LstmParameter *lstm_param);
#ifdef __cplusplus
}


+ 4
- 0
mindspore/lite/nnacl/lstm_parameter.h View File

@@ -32,6 +32,10 @@ typedef struct LstmParameter {
bool bidirectional_;
float zoneout_cell_;
float zoneout_hidden_;
int input_row_align_;
int input_col_align_;
int state_row_align_;
int state_col_align_;
int col_align_;
int row_align_;
} LstmParameter;


+ 114
- 79
mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc View File

@@ -31,78 +31,98 @@ using mindspore::schema::PrimitiveType_LSTM;

namespace mindspore::kernel {
void LstmCPUKernel::FreeTmpBuffer() {
if (weight_i_ptr_ != nullptr) {
free(weight_i_ptr_);
weight_i_ptr_ = nullptr;
}
if (input_bias_ != nullptr) {
free(input_bias_);
input_bias_ = nullptr;
}
if (!is_vec_) {
if (weight_i_ptr_ != nullptr) {
free(weight_i_ptr_);
weight_i_ptr_ = nullptr;
}
if (weight_h_ptr_ != nullptr) {
free(weight_h_ptr_);
weight_h_ptr_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
}
if (state_bias_ != nullptr) {
free(state_bias_);
state_bias_ = nullptr;
}
}

void LstmCPUKernel::FreeRunBuffer() {
context_->allocator->Free(gate_buffer_);
for (int i = 0; i < 2; i++) {
context_->allocator->Free(state_buffer_[i]);
}
context_->allocator->Free(buffer_[0]);
context_->allocator->Free(buffer_[1]);
if (!is_vec_) {
for (int i = 0; i < 2; i++) {
context_->allocator->Free(matmul_buffer_[i]);
}
context_->allocator->Free(buffer_[2]);
}
context_->allocator->Free(buffer_[3]);
}

int LstmCPUKernel::InitWeightBias() {
auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4;
if (!is_vec_) {
// malloc and init input * weight right matrix buffer
auto weight_i = in_tensors_.at(1);
MS_ASSERT(weight_i != nullptr);
weight_i_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float)));
if (weight_i_ptr_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
return RET_ERROR;
}
auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_,
lstm_param_->col_align_);
int LstmCPUKernel::InitInputWeightBias() {
// malloc and init input * weight right matrix buffer
// input -- row: seq_len * batch; col: input_size
// weight -- row: hidden_size; col: input_size, need transpose
// result -- row: seq_len * batch; col: hidden_size
auto weight_i = in_tensors_.at(1);
MS_ASSERT(weight_i != nullptr);
weight_i_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float)));
if (weight_i_ptr_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
return RET_ERROR;
}
auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, lstm_param_->hidden_size_,
lstm_param_->input_col_align_);

// input bias
input_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float)));
if (input_bias_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc input_bias_ error.";
return RET_ERROR;
}
memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float));
PackLstmBias(input_bias_, reinterpret_cast<float *>(in_tensors_.at(3)->data_c()), weight_batch_,
lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_);
return RET_OK;
}

// malloc and init state * weight right matrix buffer
auto weight_h = in_tensors_.at(2);
MS_ASSERT(weight_h != nullptr);
int LstmCPUKernel::InitStateWeightBias() {
// malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times.
// state -- row: batch; col: hidden_size
// weight -- row: hidden_size; col: hidden_size, need transpose
// result -- row: batch; col: hidden_size
auto weight_h = in_tensors_.at(2);
MS_ASSERT(weight_h != nullptr);
auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
if (!is_vec_) {
weight_h_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
if (weight_h_ptr_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error.";
return RET_ERROR;
}
auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
lstm_param_->col_align_);

// init bias
int bias_batch = lstm_param_->bidirectional_ ? 16 : 8;
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";
return RET_ERROR;
}
memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float));
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
for (int i = 0; i < bias_batch; i++) {
auto src_batch = bias_data + i * lstm_param_->hidden_size_;
auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_;
memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float));
}
PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
lstm_param_->state_col_align_);
} else {
weight_h_ptr_ = weight_h_data;
}

// state bias
state_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float)));
if (state_bias_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state_bias_ error.";
return RET_ERROR;
}
memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float));
auto state_bias = reinterpret_cast<float *>(in_tensors_.at(3)->data_c()) + 4 * lstm_param_->hidden_size_;
PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_,
lstm_param_->bidirectional_);
return RET_OK;
}

@@ -119,9 +139,9 @@ int LstmCPUKernel::InitParam() {
std::vector<int> w_shape = weight_i->shape();
lstm_param_->hidden_size_ = w_shape.at(1) / 4;

lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
: lstm_param_->batch_ * lstm_param_->hidden_size_;
weight_batch_ = lstm_param_->bidirectional_ ? 8 : 4;

#ifdef ENABLE_AVX
row_tile_ = C6NUM;
@@ -136,9 +156,12 @@ int LstmCPUKernel::InitParam() {
row_tile_ = C12NUM;
col_tile_ = C8NUM;
#endif
lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_);
lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_);

is_vec_ = lstm_param_->batch_ == 1;
lstm_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_);
lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_);
lstm_param_->state_row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_);
lstm_param_->state_col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_);
return RET_OK;
}

@@ -157,9 +180,16 @@ int LstmCPUKernel::ReSize() {
}

FreeTmpBuffer();
ret = InitWeightBias();
ret = InitInputWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
MS_LOG(ERROR) << "LstmCPUKernel InitInputWeightBias error.";
FreeTmpBuffer();
return RET_ERROR;
}

ret = InitStateWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "LstmCPUKernel InitStateWeightBias error.";
FreeTmpBuffer();
return RET_ERROR;
}
@@ -167,32 +197,42 @@ int LstmCPUKernel::ReSize() {
}

int LstmCPUKernel::MallocRunBuffer() {
if (!is_vec_) {
matmul_buffer_[0] = reinterpret_cast<float *>(
context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float)));
if (matmul_buffer_[0] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error.";
return RET_ERROR;
}
for (int i = 0; i < 4; i++) {
buffer_[i] = nullptr;
}
buffer_[0] = reinterpret_cast<float *>(
context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float)));
if (buffer_[0] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error.";
return RET_ERROR;
}

matmul_buffer_[1] = reinterpret_cast<float *>(
context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float)));
if (matmul_buffer_[1] == nullptr) {
buffer_[1] = reinterpret_cast<float *>(context_->allocator->Malloc(4 * lstm_param_->seq_len_ * lstm_param_->batch_ *
lstm_param_->hidden_size_ * sizeof(float)));
if (buffer_[1] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight result matirx error.";
return RET_ERROR;
}

if (!is_vec_) {
buffer_[2] = reinterpret_cast<float *>(
context_->allocator->Malloc(4 * lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float)));
if (buffer_[2] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state * weight left matirx error.";
return RET_ERROR;
}
}

gate_buffer_ = reinterpret_cast<float *>(
context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float)));
if (gate_buffer_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
buffer_[3] = reinterpret_cast<float *>(
context_->allocator->Malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float)));
if (buffer_[3] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state gate buffer error.";
return RET_ERROR;
}
state_buffer_[0] = nullptr;
state_buffer_[1] = nullptr;
if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) {
int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
state_buffer_[0] = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size));
if (state_buffer_[0] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for cell error.";
@@ -200,7 +240,7 @@ int LstmCPUKernel::MallocRunBuffer() {
}
}
if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) {
int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
state_buffer_[1] = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size));
if (state_buffer_[1] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for hidden error.";
@@ -235,18 +275,13 @@ int LstmCPUKernel::Run() {
return RET_ERROR;
}

if (is_vec_) {
weight_i_ptr_ = reinterpret_cast<float *>(in_tensors_[1]->data_c());
weight_h_ptr_ = reinterpret_cast<float *>(in_tensors_[2]->data_c());
bias_ptr_ = reinterpret_cast<float *>(in_tensors_[3]->data_c());
}
MS_ASSERT(weight_h_ptr_);
MS_ASSERT(weight_i_ptr_);
MS_ASSERT(bias_ptr_);
MS_ASSERT(gate_buffer_);
Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
MS_ASSERT(input_bias_);
MS_ASSERT(state_bias_);
Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_,
reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()),
gate_buffer_, state_buffer_, matmul_buffer_, lstm_param_);
state_buffer_, buffer_, lstm_param_);
FreeRunBuffer();
return RET_OK;
}


+ 6
- 4
mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h View File

@@ -41,16 +41,18 @@ class LstmCPUKernel : public LiteKernel {
void FreeRunBuffer();
int InitParam();
int MallocRunBuffer();
int InitWeightBias();
int InitInputWeightBias();
int InitStateWeightBias();

float *gate_buffer_ = nullptr;
float *state_buffer_[2];
float *weight_i_ptr_ = nullptr;
float *weight_h_ptr_ = nullptr;
float *bias_ptr_ = nullptr;
float *matmul_buffer_[2];
float *input_bias_ = nullptr;
float *state_bias_ = nullptr;
float *buffer_[4];
int row_tile_ = 0;
int col_tile_ = 0;
int weight_batch_ = 0;
bool is_vec_ = false;
LstmParameter *lstm_param_ = nullptr;
};


Loading…
Cancel
Save