Browse Source

!27157 [MS][LITE][ToD]Add LSTM grad op for backpropagation.

Merge pull request !27157 from amirama/export_amirama
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
e2790c37bd
13 changed files with 769 additions and 75 deletions
  1. +8
    -7
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/lstm_fp32.c
  2. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/lstm_fp32.h
  3. +182
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/lstm_grad_fp32.c
  4. +36
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/lstm_grad_fp32.h
  5. +40
    -18
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lstm_infer.c
  6. +4
    -4
      mindspore/lite/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
  7. +4
    -4
      mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc
  8. +93
    -26
      mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
  9. +22
    -7
      mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
  10. +279
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.cc
  11. +98
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.h
  12. +0
    -1
      mindspore/lite/tools/converter/CMakeLists.txt
  13. +0
    -6
      mindspore/lite/tools/converter/import/mindspore_importer.cc

+ 8
- 7
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/lstm_fp32.c View File

@@ -22,10 +22,10 @@
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32/pack_fp32.h"

void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align) {
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int *order) {
for (int i = 0; i < batch; i++) {
const float *src_batch = src + i * col * deep;
float *dst_batch = dst + i * col_align * deep;
float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align * deep;
#ifdef ENABLE_AVX
RowMajor2Col16Major(src_batch, dst_batch, col, deep);
#elif defined(ENABLE_ARM32)
@@ -36,19 +36,20 @@ void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col,
}
}

void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional) {
void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
const int *order) {
int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
for (int i = 0; i < unidirectional_batch; i++) {
const float *src_batch = src + i * col;
float *dst_batch = dst + i * col_align;
memcpy(dst_batch, src_batch, col * (int)sizeof(float));
float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align;
memcpy(dst_batch, src_batch, col * sizeof(float));
}
if (is_bidirectional) {
const float *backward_src = src + batch * col;
float *backward_dst = dst + unidirectional_batch * col_align;
for (int i = 0; i < unidirectional_batch; i++) {
const float *backward_src_batch = backward_src + i * col;
float *backward_dst_batch = backward_dst + i * col_align;
float *backward_dst_batch = backward_dst + ((order == NULL) ? i : order[i]) * col_align;
memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float));
}
}
@@ -167,7 +168,7 @@ void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight,
weight_i += deep * col;
}
#else
weight_i += deep * col;
weight_i += deep * col_align;
#endif
bias_i += col_align;
gate_i += row * col;


+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/lstm_fp32.h View File

@@ -21,9 +21,10 @@
#ifdef __cplusplus
extern "C" {
#endif
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align);
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int *order);

void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional);
void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
const int *order);

void PackLstmInput(const float *src, float *dst, int row, int deep);



+ 182
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/lstm_grad_fp32.c View File

@@ -0,0 +1,182 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "nnacl/fp32_grad/lstm_grad_fp32.h"
#include <string.h>
#include <float.h>
#include "nnacl/lstm_parameter.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/arithmetic_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32_grad/gemm.h"
#include "nnacl/fp32/lstm_fp32.h"
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/nnacl_utils.h"

static const int no_of_temp_matrices_sized_output_step = 15;
static const int no_of_temp_matrices_sized_batch_times_seq_len = 8; // 4 dW_ and 4 dV_ matrices
static const int num_of_gates = 4;

// Bump-allocates `size` floats from the scratch arena: returns the current
// head of *scrach_pad and advances the cursor past the handed-out region.
// No bounds checking is performed; the caller sizes the arena up front.
static inline float *AllocteFromScrachPad(float **scrach_pad, int size) {
  float *const region = *scrach_pad;
  *scrach_pad = region + size;
  return region;
}

// Re-packs each of the `batch` weight matrices (col x row floats apiece) into
// the transposed, row-aligned layout used by the grad GEMMs. The destination
// slot for each matrix is padded from `row` up to `row_align` rows; which
// RowMajor2RowNMajor variant runs is selected by the SIMD build flags.
void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align) {
  const int src_mat_size = col * row;
  const int dst_mat_size = col * row_align;
  const float *src_mat = src;
  float *dst_mat = dst;
  for (int b = 0; b < batch; b++) {
#ifdef ENABLE_AVX
    RowMajor2Row16Major(src_mat, dst_mat, row, col);
#elif defined(ENABLE_ARM32)
    RowMajor2Row4Major(src_mat, dst_mat, row, col);
#else
    RowMajor2Row8Major(src_mat, dst_mat, row, col);
#endif
    src_mat += src_mat_size;
    dst_mat += dst_mat_size;
  }
}

// Row-wise reduction: outMat[r] = sum of the first `n` entries of row `r` of
// inMat, where consecutive rows are `stride` floats apart (stride may exceed n
// for padded matrices). Used to fold per-element gate gradients into bias
// gradients.
void sumRows(int m, int n, int stride, float *inMat, float *outMat) {
  for (int r = 0; r < m; r++) {
    const float *src_row = inMat + r * stride;
    float acc = 0;
    for (int c = 0; c < n; c++) {
      acc += src_row[c];
    }
    outMat[r] = acc;
  }
}

// Returns the GEMM scratch requirement (in floats) as the larger of the two
// matmul shapes used during the backward step: (batch x seq_len x hidden) and
// (batch x hidden x seq_len).
int GetGemmMatMullWorkspace(int batch, int seq_len, int hidden_size) {
  const int forward_shape = MatSizeTotal(batch, seq_len, hidden_size, 0);
  const int transposed_shape = MatSizeTotal(batch, hidden_size, seq_len, 0);
  return (transposed_shape > forward_shape) ? transposed_shape : forward_shape;
}

// Total scratch (in floats) needed by LstmGradStepUnit: GEMM workspace plus
// the fixed count of output_step_-sized temporaries and the dW/dV gradient
// accumulators sized batch_ * seq_len_.
int GetRunWorkspaceSize(const LstmParameter *lstm_param) {
  const int batch = lstm_param->batch_;
  const int seq_len = lstm_param->seq_len_;
  return GetGemmMatMullWorkspace(batch, seq_len, lstm_param->hidden_size_) +
         no_of_temp_matrices_sized_output_step * lstm_param->output_step_ +
         no_of_temp_matrices_sized_batch_times_seq_len * batch * seq_len;
}

/*
 * One LSTM back-propagation time step.
 *
 * Gate convention (matching the forward kernel): I = input gate, F = forget
 * gate, G = cell candidate, O = output gate.  dC and dH are the incoming
 * gradients w.r.t. cell and hidden state and are updated IN PLACE for the
 * next (earlier) step; dY is reused as the dX output buffer.  All temporaries
 * are carved from `workspace`, which must hold at least
 * GetRunWorkspaceSize(lstm_param) floats (checked via NNACL_ASSERT below).
 *
 * BUGFIX: `temp` used to be allocated between dAi and dAf, but both GEMM
 * loops walk the four dA buffers with `dA_loop += output_step_`, assuming
 * dAg, dAi, dAf, dAo are contiguous.  With the old layout the loops consumed
 * `temp` (stale dO*O*O data) in place of dAf and never applied dAo at all.
 * `temp` is now allocated BEFORE dAg so the four buffers stay contiguous.
 *
 * NOTE(review): the dW*/dV*/dB* gradient results are left in `workspace`;
 * confirm the caller reads them from there.
 */
void LstmGradStepUnit(float *packed_input, float *output, float *input_gate, float *forget_gate, float *cell_gate,
                      float *output_gate, float *hidden_state, float *cell_state, float *dC, float *dH, float *dY,
                      float *cell_state_minus1, float *weights, float *workspace, const LstmParameter *lstm_param) {
  float *workspace_i = workspace;
  float *mat_workspace = AllocteFromScrachPad(
    &workspace_i, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->seq_len_, lstm_param->hidden_size_));
  float *tanh_c = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);
  Tanh(cell_state, lstm_param->output_step_, tanh_c);
  float *dO = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);  // dO = dH * Tanh(C_{t+1})
  ElementMul(dH, tanh_c, dO, lstm_param->output_step_);
  float *dF = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);  // dF = dC_{t+1} * C_t
  ElementMul(dC, cell_state_minus1, dF, lstm_param->output_step_);
  float *dG = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);  // dG = dC_{t+1} * I
  ElementMul(dC, input_gate, dG, lstm_param->output_step_);
  float *dI = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);  // dI = dC_{t+1} * G
  ElementMul(dC, cell_gate, dI, lstm_param->output_step_);
  // Allocate `temp` FIRST so that dAg..dAo below are contiguous (see BUGFIX).
  float *temp = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);
  float *dAg = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);  // dAg = dG * (1 - G^2)
  ElementMul(cell_gate, cell_gate, dAg, lstm_param->output_step_);
  ElementMul(dG, dAg, dAg, lstm_param->output_step_);
  ElementSub(dG, dAg, dAg, lstm_param->output_step_);
  float *dAi = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);  // dAi = dI * I * (1 - I)
  ElementMul(dI, input_gate, dAi, lstm_param->output_step_);
  ElementMul(dAi, input_gate, temp, lstm_param->output_step_);
  ElementSub(dAi, temp, dAi, lstm_param->output_step_);
  float *dAf = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);  // dAf = dF * F * (1 - F)
  ElementMul(dF, forget_gate, dAf, lstm_param->output_step_);
  ElementMul(dAf, forget_gate, temp, lstm_param->output_step_);
  ElementSub(dAf, temp, dAf, lstm_param->output_step_);
  float *dAo = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);  // dAo = dO * O * (1 - O)
  ElementMul(dO, output_gate, dAo, lstm_param->output_step_);
  ElementMul(dAo, output_gate, temp, lstm_param->output_step_);
  ElementSub(dAo, temp, dAo, lstm_param->output_step_);

  // dX (aliasing dY): accumulate dA_* x W_*^T over the four gates.
  // NOTE(review): dX is zeroed/written as a batch_ x seq_len_ matrix (no
  // input_size_ factor) — confirm against the caller's buffer size.
  float *dX = dY;
  memset(dX, 0, lstm_param->batch_ * lstm_param->seq_len_ * sizeof(float));
  float *weights_loop = weights;
  float *dA_loop = dAg;  // walks the contiguous dAg, dAi, dAf, dAo buffers
  for (int idx = 0; idx < num_of_gates; idx++) {
    GemmMatmul(0, 1, lstm_param->batch_, lstm_param->seq_len_, lstm_param->hidden_size_, 1.0, dA_loop,
               lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dX, lstm_param->seq_len_,
               mat_workspace);
    weights_loop += lstm_param->hidden_size_ * lstm_param->seq_len_;
    dA_loop += lstm_param->output_step_;
  }
  // Input-weight gradients: dW_* = X^T x dA_*.
  float *dWg = AllocteFromScrachPad(&workspace_i, lstm_param->batch_ * lstm_param->hidden_size_);
  GemmMatmul(1, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->seq_len_, 1.0, packed_input,
             lstm_param->batch_, dAg, lstm_param->hidden_size_, 0.0, dWg, lstm_param->hidden_size_, mat_workspace);

  float *dWi = AllocteFromScrachPad(&workspace_i, lstm_param->batch_ * lstm_param->hidden_size_);
  GemmMatmul(1, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->seq_len_, 1.0, packed_input,
             lstm_param->batch_, dAi, lstm_param->hidden_size_, 0.0, dWi, lstm_param->hidden_size_, mat_workspace);

  float *dWf = AllocteFromScrachPad(&workspace_i, lstm_param->batch_ * lstm_param->hidden_size_);
  GemmMatmul(1, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->seq_len_, 1.0, packed_input,
             lstm_param->batch_, dAf, lstm_param->hidden_size_, 0.0, dWf, lstm_param->hidden_size_, mat_workspace);

  float *dWo = AllocteFromScrachPad(&workspace_i, lstm_param->batch_ * lstm_param->hidden_size_);
  GemmMatmul(1, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->seq_len_, 1.0, packed_input,
             lstm_param->batch_, dAo, lstm_param->hidden_size_, 0.0, dWo, lstm_param->hidden_size_, mat_workspace);

  // dH for the previous step: accumulate dA_* x V_*^T over the four gates.
  // NOTE(review): weights_loop deliberately continues past the four input
  // weight matrices into the recurrent ones — confirm the packed layout is
  // [Wg Wi Wf Wo Vg Vi Vf Vo].
  memset(dH, 0, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
  dA_loop = dAg;
  for (int idx = 0; idx < num_of_gates; idx++) {
    GemmMatmul(0, 1, lstm_param->batch_, lstm_param->seq_len_, lstm_param->hidden_size_, 1.0, dA_loop,
               lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->seq_len_,
               mat_workspace);
    weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_;
    dA_loop += lstm_param->output_step_;
  }
  // Recurrent-weight gradients: dV_* = H^T x dA_*.
  float *dVg = AllocteFromScrachPad(&workspace_i, lstm_param->batch_ * lstm_param->hidden_size_);
  GemmMatmul(1, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->seq_len_, 1.0, hidden_state,
             lstm_param->batch_, dAg, lstm_param->hidden_size_, 0.0, dVg, lstm_param->hidden_size_, mat_workspace);

  float *dVi = AllocteFromScrachPad(&workspace_i, lstm_param->batch_ * lstm_param->hidden_size_);
  GemmMatmul(1, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->seq_len_, 1.0, hidden_state,
             lstm_param->batch_, dAi, lstm_param->hidden_size_, 0.0, dVi, lstm_param->hidden_size_, mat_workspace);

  float *dVf = AllocteFromScrachPad(&workspace_i, lstm_param->batch_ * lstm_param->hidden_size_);
  GemmMatmul(1, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->seq_len_, 1.0, hidden_state,
             lstm_param->batch_, dAf, lstm_param->hidden_size_, 0.0, dVf, lstm_param->hidden_size_, mat_workspace);

  float *dVo = AllocteFromScrachPad(&workspace_i, lstm_param->batch_ * lstm_param->hidden_size_);
  GemmMatmul(1, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->seq_len_, 1.0, hidden_state,
             lstm_param->batch_, dAo, lstm_param->hidden_size_, 0.0, dVo, lstm_param->hidden_size_, mat_workspace);

  // Bias gradients: dB_* = row-sums of dA_*.
  float *dBg = AllocteFromScrachPad(&workspace_i, lstm_param->batch_);
  sumRows(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dAg, dBg);
  float *dBi = AllocteFromScrachPad(&workspace_i, lstm_param->batch_);
  sumRows(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dAi, dBi);
  float *dBf = AllocteFromScrachPad(&workspace_i, lstm_param->batch_);
  sumRows(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dAf, dBf);
  float *dBo = AllocteFromScrachPad(&workspace_i, lstm_param->batch_);
  // Sanity-check that the allocations above fit the workspace the caller sized.
  NNACL_ASSERT(workspace_i <= workspace + GetRunWorkspaceSize(lstm_param));
  sumRows(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dAo, dBo);

  // Cell gradient for the previous step:
  // dC_t = dC_{t+1} * F + dH * O * (1 - Tanh(C_t)^2)
  ElementMul(dC, forget_gate, dC, lstm_param->output_step_);
  ElementMul(dH, output_gate, temp, lstm_param->output_step_);
  ElementAdd(dC, temp, dC, lstm_param->output_step_);

  Tanh(cell_state_minus1, lstm_param->output_step_, tanh_c);
  ElementMul(tanh_c, tanh_c, tanh_c, lstm_param->output_step_);
  ElementMul(temp, tanh_c, temp, lstm_param->output_step_);
  ElementSub(dC, temp, dC, lstm_param->output_step_);
}

+ 36
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/lstm_grad_fp32.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_NNACL_FP32_GRAD_LSTM_GRAD_H_
#define MINDSPORE_NNACL_FP32_GRAD_LSTM_GRAD_H_

#include "nnacl/lstm_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif

// Returns the scratch workspace size (in floats) that LstmGradStepUnit needs
// for the given LSTM configuration.
int GetRunWorkspaceSize(const LstmParameter *lstm_param);

// Packs `batch` weight matrices (col x row each) into the transposed,
// row-aligned layout consumed by the backward GEMMs (rows padded to row_align).
void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align);

// Runs one backward LSTM time step: consumes the recorded forward gate
// activations and updates dC/dH in place for the next (earlier) step.
// `workspace` must hold at least GetRunWorkspaceSize(lstm_param) floats.
// NOTE(review): `last_cell` here corresponds to `cell_state_minus1` (C_{t-1})
// in the implementation — consider unifying the parameter name.
void LstmGradStepUnit(float *packed_input, float *output, float *input_gate, float *forget_gate, float *cell_gate,
float *output_gate, float *hidden_state, float *cell_state, float *dC, float *dH, float *dY,
float *last_cell, float *weights, float *workspace, const LstmParameter *lstm_param);

#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_FP32_GRAD_LSTM_GRAD_H_

+ 40
- 18
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lstm_infer.c View File

@@ -17,6 +17,9 @@
#include "nnacl/infer/lstm_infer.h"
#include "nnacl/infer/infer_register.h"

static const int num_of_gates = 4;
static const int no_of_recorde_values = 7;

int CheckInputShapeValid(const TensorC *const *inputs, const LstmParameter *parameter) {
const TensorC *input = inputs[FIRST_INPUT];
const TensorC *weight_i = inputs[SECOND_INPUT];
@@ -50,7 +53,7 @@ int CheckInputShapeValid(const TensorC *const *inputs, const LstmParameter *para

int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 6, 3);
int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 3);
if (check_ret != NNACL_OK) {
return check_ret;
}
@@ -68,34 +71,53 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
return NNACL_INFER_INVALID;
}

if (CheckInputShapeValid(inputs, param) != NNACL_OK) {
return NNACL_ERR;
}

int hidden_size = weight_i->shape_[1] / 4;
int out_shape[MAX_SHAPE_SIZE];
size_t out_shape_size = 0;
int hidden_size = 1;
ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_);
out_shape[2] = hidden_size;
if (param->bidirectional_) {
int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
if (inputs_size == DIMENSION_4D) { // if input from MINDIR
hidden_size = weight_i->shape_[THIRD_INPUT];
out_shape[THIRD_INPUT] = hidden_size;
} else {
int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1);
if (ret != NNACL_OK) {
if (CheckInputShapeValid(inputs, param) != NNACL_OK) {
return NNACL_ERR;
}
hidden_size = weight_i->shape_[1] / num_of_gates;
out_shape[2] = hidden_size;
if (param->bidirectional_) {
int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
} else {
int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
}
}
SetShapeArray(output, out_shape, out_shape_size);
int state_shape[MAX_SHAPE_SIZE];
size_t state_shape_size = 0;
int dir_multiplier = param->bidirectional_ ? 2 : 1;
ShapeSet(state_shape, &state_shape_size, input->shape_, input->shape_size_);
state_shape[0] = param->bidirectional_ ? 2 : 1;
state_shape[2] = hidden_size;
SetShapeArray(outputs[1], state_shape, state_shape_size);
SetShapeArray(outputs[2], state_shape, state_shape_size);
state_shape[FIRST_INPUT] = dir_multiplier;
state_shape[THIRD_INPUT] = hidden_size;
SetShapeArray(outputs[SECOND_INPUT], state_shape, state_shape_size);
SetShapeArray(outputs[THIRD_INPUT], state_shape, state_shape_size);

if (outputs_size > DIMENSION_4D) {
int intermediate_states_shape[MAX_SHAPE_SIZE];
size_t intermediate_states_shape_size = 1;
int batch_size = input->shape_[SECOND_INPUT];
int seq_len = input->shape_[FIRST_INPUT];
intermediate_states_shape[FIRST_INPUT] = no_of_recorde_values * batch_size * hidden_size * seq_len * dir_multiplier;
SetDataTypeFormat(outputs[FOURTH_INPUT], inputs[FIRST_INPUT]);
SetShapeArray(outputs[FOURTH_INPUT], intermediate_states_shape, intermediate_states_shape_size);

SetDataTypeFormat(outputs[FIFTH_INPUT], inputs[FIRST_INPUT]);
SetShapeArray(outputs[FIFTH_INPUT], state_shape, state_shape_size);
}

return NNACL_OK;
}


+ 4
- 4
mindspore/lite/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc View File

@@ -38,7 +38,7 @@ int LstmFP32Coder::InitInputWeightBias(CoderContext *const context) {
init_code.CodeMallocExpression(weight_i_ptr_, weight_i_size);
init_code.CodeFunction("memset", weight_i_ptr_, 0, weight_i_size);
init_code.CodeFunction("PackLstmWeight", weight_i_ptr_, weight_i, weight_batch_, lstm_param_->input_size_,
lstm_param_->hidden_size_, lstm_param_->input_col_align_);
lstm_param_->hidden_size_, lstm_param_->input_col_align_, "NULL");

Tensor *bias_i = input_tensors_.at(kInputSize2);
MS_CHECK_PTR(bias_i);
@@ -48,7 +48,7 @@ int LstmFP32Coder::InitInputWeightBias(CoderContext *const context) {
init_code.CodeMallocExpression(input_bias_, bias_i_size);
init_code.CodeFunction("memset", input_bias_, 0, bias_i_size);
init_code.CodeFunction("PackLstmBias", input_bias_, bias_i, weight_batch_, lstm_param_->hidden_size_,
lstm_param_->input_col_align_, lstm_param_->bidirectional_);
lstm_param_->input_col_align_, lstm_param_->bidirectional_, "NULL");
context->AppendInitCode(init_code.str());
return RET_OK;
}
@@ -64,7 +64,7 @@ int LstmFP32Coder::InitStateWeightBias(CoderContext *const context) {
init_code.CodeMallocExpression(weight_i_ptr_, weight_h_size);
init_code.CodeFunction("memset", weight_i_ptr_, 0, weight_h_size);
init_code.CodeFunction("PackLstmWeight", weight_h_ptr_, weight_h, weight_batch_, lstm_param_->hidden_size_,
lstm_param_->hidden_size_, lstm_param_->state_col_align_);
lstm_param_->hidden_size_, lstm_param_->state_col_align_, "NULL");
} else {
size_t weight_h_size = weight_h->Size();
weight_h_ptr_ =
@@ -84,7 +84,7 @@ int LstmFP32Coder::InitStateWeightBias(CoderContext *const context) {
std::string state_bias_addr =
allocator_->GetRuntimeAddr(bias_i) + "+" + std::to_string(4 * lstm_param_->hidden_size_);
init_code.CodeFunction("PackLstmBias", state_bias_, state_bias_addr, weight_batch_, lstm_param_->hidden_size_,
lstm_param_->state_col_align_, lstm_param_->bidirectional_);
lstm_param_->state_col_align_, lstm_param_->bidirectional_, "NULL");
context->AppendInitCode(init_code.str());
return RET_OK;
}


+ 4
- 4
mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc View File

@@ -113,7 +113,7 @@ int GruCPUKernel::InitInputWeightBias() {
auto weight_g_data = reinterpret_cast<float *>(weight_g->data());
CHECK_NULL_RETURN(weight_g_data);
PackLstmWeight(weight_g_ptr_, weight_g_data, weight_batch_, gru_param_->input_size_, gru_param_->hidden_size_,
gru_param_->input_col_align_);
gru_param_->input_col_align_, nullptr);

// input bias
input_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * gru_param_->input_col_align_ * sizeof(float)));
@@ -125,7 +125,7 @@ int GruCPUKernel::InitInputWeightBias() {
auto bias_g_data = reinterpret_cast<float *>(in_tensors_.at(bias_index)->data());
CHECK_NULL_RETURN(bias_g_data);
PackLstmBias(input_bias_, bias_g_data, weight_batch_, gru_param_->hidden_size_, gru_param_->input_col_align_,
gru_param_->bidirectional_);
gru_param_->bidirectional_, nullptr);
return RET_OK;
}

@@ -146,7 +146,7 @@ int GruCPUKernel::InitStateWeightBias() {
return RET_ERROR;
}
PackLstmWeight(weight_r_ptr_, weight_r_data, weight_batch_, gru_param_->hidden_size_, gru_param_->hidden_size_,
gru_param_->state_col_align_);
gru_param_->state_col_align_, nullptr);
} else {
weight_r_ptr_ = weight_r_data;
}
@@ -162,7 +162,7 @@ int GruCPUKernel::InitStateWeightBias() {
CHECK_NULL_RETURN(bias_r_data);
auto state_bias = bias_r_data + gate_num * gru_param_->hidden_size_;
PackLstmBias(state_bias_, state_bias, weight_batch_, gru_param_->hidden_size_, gru_param_->state_col_align_,
gru_param_->bidirectional_);
gru_param_->bidirectional_, nullptr);
return RET_OK;
}



+ 93
- 26
mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc View File

@@ -30,8 +30,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LSTM;

namespace mindspore::kernel {
constexpr static int kWorkspaceOutIdx = 4;

void LstmCPUKernel::FreeTmpBuffer() {
if (weight_i_ptr_ != nullptr) {
free(weight_i_ptr_);
@@ -83,17 +81,19 @@ int LstmCPUKernel::InitInputWeightBias() {
// input -- row: seq_len * batch; col: input_size
// weight -- row: hidden_size; col: input_size, need transpose
// result -- row: seq_len * batch; col: hidden_size
auto weight_i = in_tensors_.at(weight_i_index);
weight_i_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float)));
if (weight_i_ptr_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
return RET_ERROR;
}
int i_index = (in_tensors_.size() == mindir_input_tensors) ? combined_weights_index : onnx_weight_i_index;
const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr;
auto weight_i = in_tensors_.at(i_index);
auto weight_i_data = reinterpret_cast<float *>(weight_i->data());
CHECK_NULL_RETURN(weight_i_data);
PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, lstm_param_->hidden_size_,
lstm_param_->input_col_align_);
lstm_param_->input_col_align_, weights_order);

// input bias
input_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float)));
@@ -102,10 +102,15 @@ int LstmCPUKernel::InitInputWeightBias() {
return RET_ERROR;
}
memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float));
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(bias_index)->data());
float *bias_data =
weight_i_data + gate_num * lstm_param_->hidden_size_ * (lstm_param_->input_size_ + lstm_param_->hidden_size_);
if (in_tensors_.size() > mindir_input_tensors) {
bias_data = reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data());
}

CHECK_NULL_RETURN(bias_data);
PackLstmBias(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_col_align_,
lstm_param_->bidirectional_);
lstm_param_->bidirectional_, weights_order);
return RET_OK;
}

@@ -114,8 +119,15 @@ int LstmCPUKernel::InitStateWeightBias() {
// state -- row: batch; col: hidden_size
// weight -- row: hidden_size; col: hidden_size, need transpose
// result -- row: batch; col: hidden_size
auto weight_h = in_tensors_.at(weight_h_index);
auto weight_h_data = reinterpret_cast<float *>(weight_h->data());
int weight_i_size = gate_num * lstm_param_->hidden_size_ * lstm_param_->input_size_;
// int weight_h_size = gate_num * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
int h_index = (in_tensors_.size() == mindir_input_tensors) ? combined_weights_index : onnx_weight_h_index;
auto weight_h = in_tensors_.at(h_index);
auto weight_h_data = (reinterpret_cast<float *>(weight_h->data()));
if (in_tensors_.size() == mindir_input_tensors) {
weight_h_data += weight_i_size;
}

CHECK_NULL_RETURN(weight_h_data);
if (!state_is_vec_) {
weight_h_ptr_ = reinterpret_cast<float *>(
@@ -124,8 +136,9 @@ int LstmCPUKernel::InitStateWeightBias() {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error.";
return RET_ERROR;
}
const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr;
PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
lstm_param_->state_col_align_);
lstm_param_->state_col_align_, weights_order);
} else {
#ifdef ENABLE_AVX
weight_h_ptr_ = reinterpret_cast<float *>(
@@ -151,24 +164,33 @@ int LstmCPUKernel::InitStateWeightBias() {
return RET_ERROR;
}
memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float));
auto state_bias =
reinterpret_cast<float *>(in_tensors_.at(bias_index)->data()) + gate_num * lstm_param_->hidden_size_;
CHECK_NULL_RETURN(state_bias);
PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_,
lstm_param_->bidirectional_);
// if ONNX, second bias is also present
if (in_tensors_.size() > mindir_input_tensors) {
float *state_bias =
reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data()) + gate_num * lstm_param_->hidden_size_;
CHECK_NULL_RETURN(state_bias);
PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_,
lstm_param_->bidirectional_, nullptr);
}
return RET_OK;
}

int LstmCPUKernel::InitParam() {
auto input = in_tensors_.front();
std::vector<int> in_shape = input->shape();
lstm_param_->seq_len_ = in_shape.at(0);
lstm_param_->batch_ = in_shape.at(1);
lstm_param_->input_size_ = in_shape.at(2);
lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT);
lstm_param_->batch_ = in_shape.at(SECOND_INPUT);
lstm_param_->input_size_ = in_shape.at(THIRD_INPUT);

auto weight_i = in_tensors_.at(weight_i_index);
auto weight_i = in_tensors_.at(onnx_weight_i_index);
std::vector<int> w_shape = weight_i->shape();
lstm_param_->hidden_size_ = w_shape.at(1) / gate_num;
if (in_tensors_.size() == mindir_input_tensors) {
hidden_state_input_index_ = mindir_hidden_state_input_index;
cell_state_input_index_ = mindir_cell_state_input_index;
lstm_param_->hidden_size_ = w_shape.at(THIRD_INPUT);
} else {
lstm_param_->hidden_size_ = w_shape.at(SECOND_INPUT) / gate_num;
}

lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
: lstm_param_->batch_ * lstm_param_->hidden_size_;
@@ -214,7 +236,7 @@ int LstmCPUKernel::InitParam() {
}

int LstmCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_6D);
CHECK_LESS_RETURN(in_tensors_.size(), mindir_input_tensors);
for (size_t i = 0; i < in_tensors_.size(); i++) {
CHECK_NULL_RETURN(in_tensors_.at(i));
}
@@ -371,18 +393,63 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons
float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
float *output_ptr = output + real_t * lstm_param_->output_step_;

if (IsTrain() && IsTrainable()) {
RecordPreState(cell_state, is_backward ? real_t : t);
}
LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
hidden_state, cell_state, buffer_, lstm_param_);
if (IsTrain() && IsTrainable()) {
RecordStates(cell_state, real_t);
RecordStates(hidden_state, cell_state, input_gate_t, output_gate_t, forget_gate_t, cell_gate_t,
is_backward ? real_t : t);
}
}
return RET_OK;
}

void LstmCPUKernel::RecordStates(const float *cell_state, int step) {
float *workspace = reinterpret_cast<float *>(out_tensors_[kWorkspaceOutIdx]->MutableData());
workspace[step] = *cell_state;
// Saves the pre-step cell state C_{t-1} for time `step` into the last record
// plane of the intermediate-states output tensor, so the LSTM grad kernel can
// replay it during backpropagation.
// Fix: removed leftover per-element std::cout debug output, which performed
// O(batch * hidden_size) console I/O on every timestep of every training run.
void LstmCPUKernel::RecordPreState(float *cell_state_minus1, int step) {
  float *states = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data());
  auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
  auto stride = step * state_size;
  // The C_{t-1} plane is laid out with seq_len_ + 1 slots per sequence.
  // NOTE(review): the other recorded planes use seq_len_ slots (see
  // RecordStates); confirm the output tensor is sized for this wider plane.
  auto seq_stride = (lstm_param_->seq_len_ + 1) * state_size;
  stride += (no_of_recorde_values - 1) * seq_stride;
  memcpy(states + stride, cell_state_minus1, state_size * sizeof(float));
}

// Records the internal LSTM states for time `step` into the intermediate-states
// output tensor, consumed later by the grad kernel. The tensor is laid out as
// consecutive planes of seq_len_ * batch_ * hidden_size_ floats, in the order:
// hidden state, cell state, input gate, output gate, forget gate, cell gate.
// Fix: removed leftover per-element std::cout debug output, which performed
// O(batch * hidden_size) console I/O on every timestep of every training run.
void LstmCPUKernel::RecordStates(float *hidden_state, float *cell_state, float *input_gate, float *output_gate,
                                 float *forget_gate, float *cell_gate, int step) {
  float *states = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data());
  auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
  auto stride = step * state_size;
  auto seq_stride = lstm_param_->seq_len_ * state_size;

  memcpy(states + stride, hidden_state, state_size * sizeof(float));
  stride += seq_stride;
  memcpy(states + stride, cell_state, state_size * sizeof(float));
  stride += seq_stride;
  memcpy(states + stride, input_gate, state_size * sizeof(float));
  stride += seq_stride;
  memcpy(states + stride, output_gate, state_size * sizeof(float));
  stride += seq_stride;
  memcpy(states + stride, forget_gate, state_size * sizeof(float));
  stride += seq_stride;
  memcpy(states + stride, cell_gate, state_size * sizeof(float));
}

int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state) {
@@ -427,9 +494,9 @@ int LstmCPUKernel::Run() {
auto output_ptr = reinterpret_cast<float *>(output->data());
CHECK_NULL_RETURN(output_ptr);

auto hidden_state = in_tensors_.at(4);
auto hidden_state = in_tensors_.at(hidden_state_input_index_);
CHECK_NULL_RETURN(hidden_state->data());
auto cell_state = in_tensors_.at(5);
auto cell_state = in_tensors_.at(cell_state_input_index_);
CHECK_NULL_RETURN(cell_state->data());

auto output_hidden_state = out_tensors_[1];


+ 22
- 7
mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h View File

@@ -49,10 +49,12 @@ class LstmCPUKernel : public InnerKernel {
int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias,
const float *state_bias, float *hidden_state, float *cell_state, bool is_backward);
int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state);
void RecordStates(const float *cell_state, int step);
const float *weight_loop_ = nullptr;
const float *bias_loop_ = nullptr;
float *gate_loop_ = nullptr;
void RecordStates(float *hidden_state, float *cell_state, float *input_gate, float *output_gate, float *forget_gate,
float *cell_gate, int step);
void RecordPreState(float *cell_state_minus1, int step);
const float *weight_loop_;
const float *bias_loop_;
float *gate_loop_;
int input_thread_count_ = 0;
int input_thread_stride_ = 0;

@@ -60,9 +62,20 @@ class LstmCPUKernel : public InnerKernel {
float *weight_h_ptr_ = nullptr;
float *input_bias_ = nullptr;
float *state_bias_ = nullptr;
const int weight_i_index = 1;
const int weight_h_index = 2;
const int bias_index = 3;
// indices of weights when split
const size_t mindir_input_tensors = 4;
const int onnx_weight_i_index = 1;
const int onnx_weight_h_index = 2;
const int onnx_bias_index = 3;
const int onnx_hidden_state_index = 4;
const int onnx_cell_state_index = 5;
// index of combined weightes when combined
const int combined_weights_index = 3;
const int mindir_hidden_state_input_index = 1;
const int mindir_cell_state_input_index = 2;
const int no_of_recorde_values = 7;
int hidden_state_input_index_ = onnx_hidden_state_index;
int cell_state_input_index_ = onnx_cell_state_index;

float *buffer_[7] = {nullptr};
const int gate_num = 4;
@@ -73,6 +86,8 @@ class LstmCPUKernel : public InnerKernel {
const int cell_state_index = 4;
const int hidden_state_index = 5;
const int avx_state_output_index = 6;
static const int out_intermediate_states_index = 3;
const int weights_order_IFOG[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 4}; // IFGO order to IOFG order

int row_tile_ = 0;
int col_tile_ = 0;


+ 279
- 0
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.cc View File

@@ -0,0 +1,279 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.h"
#include <string>
#include <memory>
#include "utils/ms_utils.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "nnacl/fp32/lstm_fp32.h"

namespace mindspore {
namespace kernel {
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LSTMGrad;

// Validate tensor counts; defer parameter setup until shapes are inferred.
int LSTMGradCPUKernel::Prepare() {
  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_11D);
  CHECK_LESS_RETURN(out_tensors_.size(), 1);
  return InferShapeDone() ? ReSize() : RET_OK;
}

int LSTMGradCPUKernel::ReSize() { return InitParam(); }

int LSTMGradCPUKernel::Run() {
auto ret = MallocRunBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "LstmGradCPUKernel MallocRunBuffer error.";
FreeRunBuffer();
return RET_ERROR;
}

PackWeights();
auto output = out_tensors_.at(0);
auto output_ptr = reinterpret_cast<float *>(output->data());
CHECK_NULL_RETURN(output_ptr);

LstmBackpropUnidirectional(output_ptr, false);
FreeRunBuffer();
return RET_OK;
}

// Backpropagation-through-time over one direction of the LSTM.
// Walks the sequence from the last step to the first, replaying the states and
// gate activations recorded by the forward kernel and accumulating dC/dH.
// Fixes: removed per-step std::cout debug output and a stale commented-out
// line; dropped the unused local `last_cell` (its tensor is still null-checked).
int LSTMGradCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backward) {
  auto dC_tensor = in_tensors_.at(dC_index);           /* [1, Batch, hidden_size ] */
  auto dH_tensor = in_tensors_.at(dH_index);           /* [1, Batch, hidden_size ] */
  auto dy_tensor = in_tensors_.at(dy_index);
  auto cell_tensor = in_tensors_.at(cell_state_index); /* [1, Batch, hidden_size ] */
  auto weights_tensor = in_tensors_.at(weights_index); /* [all weights + biases, 1, 1] */
  auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
  auto input_tensor = in_tensors_.at(input_index);
  MS_ASSERT(dy_tensor != nullptr);
  MS_ASSERT(dC_tensor != nullptr);
  MS_ASSERT(dH_tensor != nullptr);
  MS_ASSERT(cell_tensor != nullptr);
  MS_ASSERT(weights_tensor != nullptr);
  MS_ASSERT(intermediate_tensor != nullptr);
  MS_ASSERT(input_tensor != nullptr);
  auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data());
  auto dC = reinterpret_cast<float *>(dC_tensor->data());
  auto dH = reinterpret_cast<float *>(dH_tensor->data());
  auto dY = reinterpret_cast<float *>(dy_tensor->data());
  auto weights = reinterpret_cast<float *>(weights_tensor->data());
  auto input = reinterpret_cast<float *>(input_tensor->data());

  // The intermediate tensor holds 7 planes of [seq_len][batch * hidden_size]:
  // hidden, cell, input gate, output gate, forget gate, cell gate, cell(t-1).
  auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
  auto seq_stride = lstm_param_->seq_len_ * state_size;
  float *hidden_state = intermediate_data;
  float *cell_state = intermediate_data + seq_stride * 1;
  float *input_gate = intermediate_data + seq_stride * 2;
  float *output_gate = intermediate_data + seq_stride * 3;
  float *forget_gate = intermediate_data + seq_stride * 4;
  float *cell_gate = intermediate_data + seq_stride * 5;
  float *cell_state_minus1 = intermediate_data + seq_stride * 6;
  for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) {
    // For the reverse direction the recorded step order is mirrored.
    int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
    auto stride = real_t * state_size;
    float *hidden_state_t = hidden_state + stride;
    float *cell_state_t = cell_state + stride;
    float *input_gate_t = input_gate + stride;
    float *forget_gate_t = forget_gate + stride;
    float *cell_gate_t = cell_gate + stride;
    float *output_gate_t = output_gate + stride;
    float *cell_state_minus1_t = cell_state_minus1 + stride;
    float *output_ptr = output + real_t * lstm_param_->output_step_;
    float *input_ptr = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_;

    // dC and dH are carried (and accumulated) from step t+1 into step t.
    LstmGradStepUnit(input_ptr, output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, hidden_state_t,
                     cell_state_t, dC, dH, dY, cell_state_minus1_t, weights, workspace_, lstm_param_);
  }
  return RET_OK;
}

int LSTMGradCPUKernel::DoGrad(int thread_id) { return RET_OK; }

float *LSTMGradCPUKernel::InputWeightPtr() { return reinterpret_cast<float *>(in_tensors_.at(weights_index)->data()); }

// Recurrent (hidden-to-hidden) weights follow the input weights in the packed tensor.
float *LSTMGradCPUKernel::StateWeightPtr() {
  const int input_weights_count = num_of_gates * lstm_param_->hidden_size_ * lstm_param_->input_size_;
  auto *base = reinterpret_cast<float *>(in_tensors_.at(weights_index)->data());
  return base + input_weights_count;
}

// Input biases follow all gate weight matrices (input and recurrent) in the packed tensor.
float *LSTMGradCPUKernel::InputBiasPtr() {
  const int all_weights_count =
    num_of_gates * lstm_param_->hidden_size_ * (lstm_param_->input_size_ + lstm_param_->hidden_size_);
  auto *base = reinterpret_cast<float *>(in_tensors_.at(weights_index)->data());
  return base + all_weights_count;
}

// State biases sit after all gate weights and the input-bias block.
// Fix: the input-bias block holds num_of_gates * hidden_size_ floats, but the
// previous offset added num_of_gates * (num_of_gates + hidden_size_), which
// over-counts by num_of_gates * num_of_gates elements.
float *LSTMGradCPUKernel::StateBiasPtr() {
  int bias_offset = num_of_gates * lstm_param_->hidden_size_ * (lstm_param_->input_size_ + lstm_param_->hidden_size_);
  bias_offset += num_of_gates * lstm_param_->hidden_size_;  // skip the input biases
  return (reinterpret_cast<float *>(in_tensors_.at(weights_index)->data()) + bias_offset);
}

// Release all long-lived packed-weight/bias buffers owned by this kernel.
void LSTMGradCPUKernel::FreeTmpBuffer() {
  auto release = [](float *&buf) {
    if (buf != nullptr) {
      free(buf);
      buf = nullptr;
    }
  };
  release(weight_i_ptr_);
  release(input_bias_);
#ifdef ENABLE_AVX
  release(weight_h_ptr_);
#else
  // Non-AVX: weight_h_ptr_ is only a separately-owned packed buffer when the
  // batched (non-vector) state matmul path allocated it.
  if (!state_is_vec_) {
    release(weight_h_ptr_);
  }
#endif
  release(state_bias_);
}

// Derive all shape- and ISA-dependent parameters from the input tensors,
// then allocate the packed weight/bias buffers.
int LSTMGradCPUKernel::InitParam() {
  // Forward input tensor is [seq_len, batch, input_size].
  auto input = in_tensors_.front();
  MS_ASSERT(input != nullptr);
  std::vector<int> in_shape = input->shape();
  lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT);
  lstm_param_->batch_ = in_shape.at(SECOND_INPUT);
  lstm_param_->input_size_ = in_shape.at(THIRD_INPUT);

  // Hidden size is read from the last dimension of the output-gradient tensor.
  auto dy = in_tensors_.at(dy_index);
  MS_ASSERT(dy != nullptr);
  std::vector<int> dy_shape = dy->shape();
  lstm_param_->hidden_size_ = dy_shape.at(THIRD_INPUT);

  // Bidirectional nets double the per-step output stride and the weight batch.
  int dir_multiplier = lstm_param_->bidirectional_ ? 2 : 1;
  lstm_param_->output_step_ = dir_multiplier * lstm_param_->batch_ * lstm_param_->hidden_size_;
  weight_batch_ = dir_multiplier * num_of_gates;
  // batch == 1 lets the state matmuls use the vector (GEMV) path.
  state_is_vec_ = lstm_param_->batch_ == 1;

  // Matmul tile sizes depend on the SIMD width of the target ISA.
#ifdef ENABLE_AVX
  row_tile_ = C6NUM;
  col_tile_ = C16NUM;
#elif defined(ENABLE_ARM32)
  row_tile_ = C12NUM;
  col_tile_ = C4NUM;
#elif defined(ENABLE_SSE)
  row_tile_ = C4NUM;
  col_tile_ = C8NUM;
#else
  row_tile_ = C12NUM;
  col_tile_ = C8NUM;
#endif
  // Round matmul dimensions up to tile multiples for the packed layouts.
  lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_);
  lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_);
  input_size_align_ = UP_ROUND(lstm_param_->input_size_, row_tile_);
  input_thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(lstm_param_->input_col_align_, col_tile_));
  input_thread_stride_ = UP_DIV(UP_DIV(lstm_param_->input_col_align_, col_tile_), input_thread_count_);

  state_row_tile_ = row_tile_;
  state_col_tile_ = col_tile_;

  // Vector path needs no row/col padding for the state matmul.
  lstm_param_->state_row_align_ = state_is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, state_row_tile_);
  lstm_param_->state_col_align_ =
    state_is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, state_col_tile_);

  return AllocateWeights();
}

int LSTMGradCPUKernel::AllocateWeights() {
if (weight_i_ptr_ == nullptr) {
weight_i_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch_ * lstm_param_->input_size_ * lstm_param_->input_col_align_ * sizeof(float)));
if (weight_i_ptr_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
return RET_ERROR;
}
}
if (input_bias_ == nullptr) {
input_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float)));
if (input_bias_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc input_bias_ error.";
return RET_ERROR;
}
}
if (weight_h_ptr_ == nullptr) {
weight_h_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
if (weight_h_ptr_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error.";
return RET_ERROR;
}
}
if (state_bias_ == nullptr) {
state_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float)));
if (state_bias_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state_bias_ error.";
return RET_ERROR;
}
}
return RET_OK;
}

int LSTMGradCPUKernel::MallocRunBuffer() {
int workspace_size = GetRunWorkspaceSize(lstm_param_);
if ((workspace_size == 0) || (workspace_size > LSTMGRAD_MAX_WORKSPACE_SIZE)) {
MS_LOG(ERROR) << "LstmCPUKernel malloc run workspace 0 error.";
return RET_ERROR;
}
workspace_ = reinterpret_cast<float *>(malloc(workspace_size * sizeof(float)));
if (workspace_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc run workspace error.";
return RET_ERROR;
}
return RET_OK;
}

// Release the per-Run scratch workspace; safe to call when nothing is allocated.
void LSTMGradCPUKernel::FreeRunBuffer() {
  if (workspace_ == nullptr) {
    return;
  }
  free(workspace_);
  workspace_ = nullptr;
}

// Repack the flat weights tensor into the transposed, tile-aligned layout the
// matmul kernels consume (input weights then recurrent weights).
int LSTMGradCPUKernel::PackWeights() {
  auto weight_i_data = InputWeightPtr();
  CHECK_NULL_RETURN(weight_i_data);
  // NOTE(review): weight_i_ptr_ was sized in AllocateWeights() with
  // input_col_align_ but is packed here with input_size_align_ as the alignment
  // argument — confirm these agree with PackLstmWeightTranspose's expected
  // buffer size, otherwise the pack may under/overrun the allocation.
  PackLstmWeightTranspose(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->hidden_size_,
                          lstm_param_->input_size_, input_size_align_);

  auto weight_h_data = StateWeightPtr();
  CHECK_NULL_RETURN(weight_h_data);
  PackLstmWeightTranspose(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_,
                          lstm_param_->hidden_size_, lstm_param_->state_col_align_);
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTMGrad, LiteKernelCreator<LSTMGradCPUKernel>)
} // namespace kernel
} // namespace mindspore

+ 98
- 0
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.h View File

@@ -0,0 +1,98 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LSTM_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LSTM_GRAD_H_

#include <vector>
#include "src/inner_kernel.h"
#include "nnacl/fp32_grad/lstm_grad_fp32.h"

namespace mindspore {
namespace kernel {
constexpr int LSTMGRAD_MAX_WORKSPACE_SIZE = 100000;
// CPU kernel computing the LSTM backward pass (fp32). Consumes the forward
// kernel's recorded intermediate states plus the incoming gradients and emits
// input gradients. Fix: nine members (weight_size_ .. reserve_size_) had no
// initializers and were read-before-write hazards; all members now brace-init.
class LSTMGradCPUKernel : public InnerKernel {
 public:
  explicit LSTMGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                             const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : InnerKernel(parameter, inputs, outputs, ctx), thread_count_(ctx->thread_num_) {
    lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
  }
  ~LSTMGradCPUKernel() { FreeTmpBuffer(); }
  int Prepare() override;
  int ReSize() override;
  int Run() override;
  int DoGrad(int thread_id);

 private:
  int LstmBackpropUnidirectional(float *output, bool is_backward);

  int InitParam();
  void FreeTmpBuffer();
  int MallocRunBuffer();
  void FreeRunBuffer();
  int InitInputWeightBias();
  int InitStateWeightBias();
  float *InputWeightPtr();
  float *StateWeightPtr();
  float *InputBiasPtr();
  float *StateBiasPtr();
  int AllocateWeights();
  int PackWeights();

  int thread_count_;
  // Input-tensor slots expected by this grad kernel.
  static const int input_index = 0;
  static const int cell_state_index = 2;
  static const int weights_index = 3;
  static const int dy_index = 7;
  static const int dH_index = 8;
  static const int dC_index = 9;
  static const int intermediate_data_index = 10;
  static const int num_of_gates = 4;

  int input_size_align_ = 1;
  // Packed weight/bias buffers, owned; released in FreeTmpBuffer().
  float *weight_i_ptr_ = nullptr;
  float *weight_h_ptr_ = nullptr;
  float *input_bias_ = nullptr;
  float *state_bias_ = nullptr;
  // Per-Run scratch buffer, owned; released in FreeRunBuffer().
  float *workspace_ = nullptr;

  int64_t weight_size_ = 0;
  int64_t weight_h_size_ = 0;
  int64_t input_size_ = 0;
  int64_t hidden_size_ = 0;
  int64_t num_layers_ = 0;
  int64_t batch_size_ = 0;
  int64_t seq_len_ = 0;
  int num_directions_ = 0;
  bool bidirectional_ = false;
  bool has_bias_ = false;
  size_t reserve_size_ = 0;
  // Matmul tiling parameters derived in InitParam().
  int row_tile_ = 0;
  int col_tile_ = 0;
  int state_row_tile_ = 0;
  int state_col_tile_ = 0;
  int weight_batch_ = 0;
  bool state_is_vec_ = false;
  int input_thread_count_ = 0;
  int input_thread_stride_ = 0;

  // Non-owning view over op_parameter_; lifetime managed by the framework.
  LstmParameter *lstm_param_ = nullptr;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LSTM_GRAD_H_

+ 0
- 1
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -43,7 +43,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/unify_format.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/lstm_adjust_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/cast_op_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc


+ 0
- 6
mindspore/lite/tools/converter/import/mindspore_importer.cc View File

@@ -340,12 +340,6 @@ FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags &
return nullptr;
}

auto lstm_adjust_pass = std::make_shared<opt::LstmAdjustPass>();
MS_CHECK_TRUE_MSG(lstm_adjust_pass != nullptr, nullptr, "lstm_adjust_pass is nullptr.");
if (!lstm_adjust_pass->Run(func_graph)) {
MS_LOG(ERROR) << "Run mindir lstm adjust failed.";
return nullptr;
}
return func_graph;
}
} // namespace mindspore::lite

Loading…
Cancel
Save