Browse Source

!31665 [MS][LITE][TOD] LSTM BiDir Support (BWD + FWD) OPS

Merge pull request !31665 from Haim/export_haim
r1.7
i-robot Gitee 4 years ago
parent
commit
02cb6ba0f9
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 582 additions and 375 deletions
  1. +50
    -0
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c
  2. +6
    -0
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
  3. +46
    -157
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/lstm_grad_fp32.c
  4. +3
    -3
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/lstm_grad_fp32.h
  5. +2
    -1
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_grad_weight_infer.c
  6. +4
    -4
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c
  7. +111
    -39
      mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
  8. +8
    -3
      mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
  9. +112
    -58
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.cc
  10. +11
    -3
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.h
  11. +106
    -52
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.cc
  12. +15
    -3
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.h
  13. +98
    -46
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_weight_fp32.cc
  14. +7
    -5
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_weight_fp32.h
  15. +3
    -1
      mindspore/lite/src/train/train_populate_parameter.cc

+ 50
- 0
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c View File

@@ -36,6 +36,37 @@ void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col,
}
}

// Packs LSTM gate weight matrices from row-major layout into the tiled
// column-major layout expected by the matmul kernels (Col16/Col4/Col8 major
// depending on the target ISA).
//
// dst              destination buffer, col_align-padded per gate matrix
// src              source weights, `batch` matrices of shape (col x deep)
// batch            total number of gate matrices (both directions combined)
// deep             inner dimension (rows of each packed tile)
// col              number of valid columns per gate matrix
// col_align        column count after alignment padding in dst
// is_bidirectional when true, `batch` covers two directions; the backward
//                  half of src starts at src + stride
// stride           element offset from forward to backward weights in src
// order            optional gate reordering table (NULL = identity)
void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align,
                              bool is_bidirectional, int stride, const int *order) {
  int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
  int num_directions = is_bidirectional ? 2 : 1;
  // One pass per direction; the backward direction reads src + stride and
  // writes immediately after the packed forward matrices.
  for (int dir = 0; dir < num_directions; dir++) {
    const float *src_dir = src + dir * stride;
    float *dst_dir = dst + dir * unidirectional_batch * col_align * deep;
    for (int i = 0; i < unidirectional_batch; i++) {
      const float *src_batch = src_dir + i * col * deep;
      float *dst_batch = dst_dir + ((order == NULL) ? i : order[i]) * col_align * deep;
#ifdef ENABLE_AVX
      RowMajor2Col16Major(src_batch, dst_batch, col, deep);
#elif defined(ENABLE_ARM32)
      RowMajor2Col4Major(src_batch, dst_batch, col, deep);
#else
      RowMajor2Col8Major(src_batch, dst_batch, col, deep);
#endif
    }
  }
}

void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
const int *order) {
int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
@@ -55,6 +86,25 @@ void PackLstmBias(float *dst, const float *src, int batch, int col, int col_alig
}
}

// Packs LSTM gate bias vectors into a col_align-padded destination buffer.
//
// dst              destination, one col_align-sized slot per gate vector
// src              source biases, `batch` contiguous vectors of length col
// batch            total number of gate bias vectors (both directions)
// col              valid elements per bias vector
// col_align        aligned slot size in dst (padding beyond col is untouched)
// is_bidirectional when true, the backward half of src starts at src + b_stride
// b_stride         element offset from forward to backward biases in src
// order            optional gate reordering table (NULL = identity)
void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
                            int b_stride, const int *order) {
  int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
  int num_directions = is_bidirectional ? 2 : 1;
  // One pass per direction; backward biases follow the packed forward ones.
  for (int dir = 0; dir < num_directions; dir++) {
    const float *src_dir = src + dir * b_stride;
    float *dst_dir = dst + dir * unidirectional_batch * col_align;
    for (int i = 0; i < unidirectional_batch; i++) {
      int dst_index = (order == NULL) ? i : order[i];
      memcpy(dst_dir + dst_index * col_align, src_dir + i * col, col * sizeof(float));
    }
  }
}

void PackLstmInput(const float *src, float *dst, int row, int deep) {
#ifdef ENABLE_AVX
RowMajor2Col6Major(src, dst, row, deep);


+ 6
- 0
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h View File

@@ -23,9 +23,15 @@ extern "C" {
#endif
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int *order);

void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align,
bool is_bidirectional, int stride, const int *order);

void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
const int *order);

void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
int b_stride, const int *order);

void PackLstmInput(const float *src, float *dst, int row, int deep);

void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align,


+ 46
- 157
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/lstm_grad_fp32.c View File

@@ -84,7 +84,7 @@ int GetGemmMatMullWorkspace(int batch, int input_size, int hidden_size) {
int workspace_size, temp;
// if the appropriate GemmMatMul uses beta > 0, MatSizeTotal must take col as its last parameter.
workspace_size = MatSizeTotal(batch, input_size, hidden_size, input_size);
temp = MatSizeTotal(hidden_size, batch, hidden_size, batch);
temp = MatSizeTotal(batch, hidden_size, hidden_size, hidden_size);
workspace_size = (temp > workspace_size) ? temp : workspace_size;
temp = MatSizeTotal(hidden_size, input_size, batch, input_size);
workspace_size = (temp > workspace_size) ? temp : workspace_size;
@@ -94,103 +94,105 @@ int GetGemmMatMullWorkspace(int batch, int input_size, int hidden_size) {
}

int GetRunWorkspaceSize(const LstmGradParameter *lstm_param) {
int workspace_size = no_of_temp_matrices_sized_output_step * lstm_param->output_step_;
int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_;
int workspace_size = no_of_temp_matrices_sized_output_step * time_stamp_len;
workspace_size += GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_);
return workspace_size;
}

size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param) {
return no_of_temp_matrices_sized_output_step * lstm_param->output_step_;
int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_;
return no_of_temp_matrices_sized_output_step * time_stamp_len;
}

void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate,
float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX,
float *weights, float *workspace, const LstmGradParameter *lstm_param) {
float *w, float *v, float *workspace, const LstmGradParameter *lstm_param) {
float *scratchPad = workspace;

float *temp0 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp1 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp2 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp3 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp4 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
int seq_len = lstm_param->batch_ * lstm_param->hidden_size_;
float *temp0 = AllocteFromScrachPad(&scratchPad, seq_len);
float *temp1 = AllocteFromScrachPad(&scratchPad, seq_len);
float *temp2 = AllocteFromScrachPad(&scratchPad, seq_len);
float *temp3 = AllocteFromScrachPad(&scratchPad, seq_len);
float *temp4 = AllocteFromScrachPad(&scratchPad, seq_len);

// Accumulate gradients into dH
ElementAdd(dH, dY, dH, lstm_param->output_step_);
ElementAdd(dH, dY, dH, seq_len);

ElementMul(dH, output_gate, temp1, lstm_param->output_step_);
Tanh(cell_state, lstm_param->output_step_, temp0);
ElementMul(temp0, temp0, temp2, lstm_param->output_step_);
ElementMul(temp1, temp2, temp4, lstm_param->output_step_);
ElementSub(temp1, temp4, temp1, lstm_param->output_step_);
ElementAdd(dC, temp1, dC, lstm_param->output_step_);
ElementMul(dH, output_gate, temp1, seq_len);
Tanh(cell_state, seq_len, temp0);
ElementMul(temp0, temp0, temp2, seq_len);
ElementMul(temp1, temp2, temp4, seq_len);
ElementSub(temp1, temp4, temp1, seq_len);
ElementAdd(dC, temp1, dC, seq_len);

// calculate dI, dO, dF and dG
float *dI = temp1; // dI = dC_{t} * G
ElementMul(dC, cell_gate, dI, lstm_param->output_step_);
ElementMul(dC, cell_gate, dI, seq_len);
float *dO = temp2; // dO = dH * Tanh(C_{t})
ElementMul(dH, temp0, dO, lstm_param->output_step_);
ElementMul(dH, temp0, dO, seq_len);
float *dF = temp3; // dF = dC_{t} * C_{t-1}
ElementMul(dC, prev_cell_state, dF, lstm_param->output_step_);
ElementMul(dC, prev_cell_state, dF, seq_len);
float *dG = temp4; // dG = dC_{t} * I
ElementMul(dC, input_gate, dG, lstm_param->output_step_);
ElementMul(dC, input_gate, dG, seq_len);

// dAi = dI * I * (1 - I)
float *dAi = temp1;
*dA = dAi;
ElementMul(dI, input_gate, dAi, lstm_param->output_step_);
ElementMul(dAi, input_gate, temp0, lstm_param->output_step_);
ElementSub(dAi, temp0, dAi, lstm_param->output_step_);
ElementMul(dI, input_gate, dAi, seq_len);
ElementMul(dAi, input_gate, temp0, seq_len);
ElementSub(dAi, temp0, dAi, seq_len);

// dAo = dO * O * (1 - O)
float *dAo = temp2;
ElementMul(dO, output_gate, dAo, lstm_param->output_step_);
ElementMul(dAo, output_gate, temp0, lstm_param->output_step_);
ElementSub(dAo, temp0, dAo, lstm_param->output_step_);
ElementMul(dO, output_gate, dAo, seq_len);
ElementMul(dAo, output_gate, temp0, seq_len);
ElementSub(dAo, temp0, dAo, seq_len);

// dAf = dF * F * (1 - F)
float *dAf = temp3;
ElementMul(dF, forget_gate, dAf, lstm_param->output_step_);
ElementMul(dAf, forget_gate, temp0, lstm_param->output_step_);
ElementSub(dAf, temp0, dAf, lstm_param->output_step_);
ElementMul(dF, forget_gate, dAf, seq_len);
ElementMul(dAf, forget_gate, temp0, seq_len);
ElementSub(dAf, temp0, dAf, seq_len);

float *dAg = temp4;
ElementMul(cell_gate, cell_gate, temp0, lstm_param->output_step_);
ElementMul(dG, temp0, temp0, lstm_param->output_step_);
ElementSub(dG, temp0, dAg, lstm_param->output_step_);
ElementMul(cell_gate, cell_gate, temp0, seq_len);
ElementMul(dG, temp0, temp0, seq_len);
ElementSub(dG, temp0, dAg, seq_len);

// calculate dX
size_t dX_size = lstm_param->batch_ * lstm_param->input_size_ * sizeof(float);
memset(dX, 0, dX_size);
float *mat_workspace = AllocteFromScrachPad(
&scratchPad, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));
float *weights_loop = weights;
float *weights_loop = w;
float *dA_loop = dAi; // dAi, dAo, dAf, dAg
for (int idx = 0; idx < num_of_gates; idx++) {
GemmMatmul(0, 0, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, 1.0, dA_loop,
lstm_param->hidden_size_, weights_loop, lstm_param->input_size_, 1.0, dX, lstm_param->input_size_,
mat_workspace);
weights_loop += lstm_param->hidden_size_ * lstm_param->input_size_;
dA_loop += lstm_param->output_step_;
dA_loop += seq_len;
}

// calculate dH next
size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float);
memset(dH, 0, dH_size);
dA_loop = dAi;
weights_loop = v;
for (int idx = 0; idx < num_of_gates; idx++) {
GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop,
lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_,
mat_workspace);
weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_;
dA_loop += lstm_param->output_step_;
dA_loop += seq_len;
}
// calculate dC next
ElementMul(dC, forget_gate, dC, lstm_param->output_step_);
ElementMul(dC, forget_gate, dC, seq_len);
}

void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *workspace,
const LstmGradParameter *lstm_param) {
void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB,
float *workspace, const LstmGradParameter *lstm_param) {
// Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg
int seq_len = lstm_param->batch_ * lstm_param->hidden_size_;
float *mat_workspace = AllocteFromScrachPad(
&workspace, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));
float *dA_loop = dA; // dAi, dAo, dAf, dAg
@@ -198,10 +200,10 @@ void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, f
int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_;
int dB_size = 0;
float *dW_loop = dW;
float *dV_loop = dW + (num_of_gates * dW_size);
float *dV_loop = dV;
float *dB_loop = 0;
if (lstm_param->has_bias_) {
dB_loop = dW + (num_of_gates * (dW_size + dV_size));
dB_loop = dB;
dB_size = lstm_param->hidden_size_;
}

@@ -218,120 +220,7 @@ void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, f
if (dB_loop != 0) {
sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true);
}
dA_loop += lstm_param->output_step_;
dW_loop += dW_size;
dV_loop += dV_size;
dB_loop += dB_size;
}
}

void LstmGradDoStep(const float *output_gate, float *cell_state, float *cell_state_minus1, float *cell_gate,
float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float *dX, float *weights,
float *dW, float *hidden_state, float *input_t, float *workspace,
const LstmGradParameter *lstm_param) {
float *workspace_i = workspace;

float buffer[1024];
float *scratchPad = buffer;

float *tanh_c = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp2 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *tanh_c_sqr = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);

// Accumulate gradients into dH
ElementAdd(dH, dY, dH, lstm_param->output_step_);

ElementMul(dH, output_gate, temp2, lstm_param->output_step_);
Tanh(cell_state, lstm_param->output_step_, tanh_c);
ElementMul(tanh_c, tanh_c, tanh_c_sqr, lstm_param->output_step_);
ElementMul(temp2, tanh_c_sqr, temp, lstm_param->output_step_);
ElementSub(temp2, temp, temp2, lstm_param->output_step_);
ElementAdd(dC, temp2, dC, lstm_param->output_step_);

// calculate dI, dO, dF and dG
float *dI = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dI = dC_{t} * G
ElementMul(dC, cell_gate, dI, lstm_param->output_step_);
float *dO = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dO = dH * Tanh(C_{t})
ElementMul(dH, tanh_c, dO, lstm_param->output_step_);
float *dF = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dF = dC_{t} * C_{t-1}
ElementMul(dC, cell_state_minus1, dF, lstm_param->output_step_);
float *dG = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dG = dC_{t} * I
ElementMul(dC, input_gate, dG, lstm_param->output_step_);

// dAi = dI * I * (1 - I)
float *dAi = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
ElementMul(dI, input_gate, dAi, lstm_param->output_step_);
ElementMul(dAi, input_gate, temp, lstm_param->output_step_);
ElementSub(dAi, temp, dAi, lstm_param->output_step_);

// dAo = dO * O * (1 - O)
float *dAo = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
ElementMul(dO, output_gate, dAo, lstm_param->output_step_);
ElementMul(dAo, output_gate, temp, lstm_param->output_step_);
ElementSub(dAo, temp, dAo, lstm_param->output_step_);

// dAf = dF * F * (1 - F)
float *dAf = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
ElementMul(dF, forget_gate, dAf, lstm_param->output_step_);
ElementMul(dAf, forget_gate, temp, lstm_param->output_step_);
ElementSub(dAf, temp, dAf, lstm_param->output_step_);

// dAg = dG * (1 - G^2)
float *dAg = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
ElementMul(cell_gate, cell_gate, dAg, lstm_param->output_step_);
ElementMul(dG, dAg, dAg, lstm_param->output_step_);
ElementSub(dG, dAg, dAg, lstm_param->output_step_);

// calculate dX
float *mat_workspace = AllocteFromScrachPad(
&workspace_i, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));
float *weights_loop = weights;
float *dA_loop = dAi; // dAi, dAo, dAf, dAg
for (int idx = 0; idx < num_of_gates; idx++) {
GemmMatmul(0, 0, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, 1.0, dA_loop,
lstm_param->hidden_size_, weights_loop, lstm_param->input_size_, 1.0, dX, lstm_param->input_size_,
mat_workspace);
weights_loop += lstm_param->hidden_size_ * lstm_param->input_size_;
dA_loop += lstm_param->output_step_;
}

// calculate dH next
size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float);
memset(dH, 0, dH_size);
dA_loop = dAi;
for (int idx = 0; idx < num_of_gates; idx++) {
GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop,
lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_,
mat_workspace);
weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_;
dA_loop += lstm_param->output_step_;
}
// calculate dC next
ElementMul(dC, forget_gate, dC, lstm_param->output_step_);

// Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg
dA_loop = dAi;
int dW_size = lstm_param->input_size_ * lstm_param->hidden_size_;
int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_;
int dB_size = lstm_param->hidden_size_;
float *dW_loop = dW;
float *dV_loop = dW + (num_of_gates * dW_size);
float *dB_loop = dW + (num_of_gates * (dW_size + dV_size));
for (int idx = 0; idx < num_of_gates; idx++) {
// Calc dW
GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->input_size_, lstm_param->batch_, 1.0, dA_loop,
lstm_param->hidden_size_, input_t, lstm_param->input_size_, 1.0, dW_loop, lstm_param->input_size_,
mat_workspace);
// Calc dV
if (hidden_state != 0) {
GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->batch_, 1.0, dA_loop,
lstm_param->hidden_size_, hidden_state, lstm_param->hidden_size_, 1.0, dV_loop,
lstm_param->hidden_size_, mat_workspace);
}
// Calc dB
sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true);
dA_loop += lstm_param->output_step_;
dA_loop += seq_len;
dW_loop += dW_size;
dV_loop += dV_size;
dB_loop += dB_size;


+ 3
- 3
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/lstm_grad_fp32.h View File

@@ -58,10 +58,10 @@ void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col,

void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate,
float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX,
float *weights, float *workspace, const LstmGradParameter *lstm_param);
float *w, float *v, float *workspace, const LstmGradParameter *lstm_param);

void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *workspace,
const LstmGradParameter *lstm_param);
void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB,
float *workspace, const LstmGradParameter *lstm_param);
#ifdef __cplusplus
}
#endif


+ 2
- 1
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_grad_weight_infer.c View File

@@ -51,7 +51,8 @@ int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, T
if (has_bias) {
output_shape[0] += C2NUM * gate_size;
}

int dir_mul = (param->bidirectional_) ? C2NUM : C1NUM;
output_shape[0] *= dir_mul;
SetShapeArray(output, output_shape, C3NUM);

return NNACL_OK;


+ 4
- 4
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c View File

@@ -18,7 +18,7 @@
#include "nnacl/infer/infer_register.h"

static const int num_of_gates = 4;
static const int no_of_recorde_values = 7;
static const int no_of_recorde_values = 6;

int CheckInputShapeValid(const TensorC *const *inputs, const LstmParameter *parameter) {
const TensorC *input = inputs[FIRST_INPUT];
@@ -70,14 +70,14 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
int dir_multiplier = param->bidirectional_ ? 2 : 1;
int out_shape[MAX_SHAPE_SIZE];
size_t out_shape_size = 0;
int hidden_size = 1;
ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_);
if (inputs_size == DIMENSION_4D) { // if input from MINDIR
hidden_size = weight_i->shape_[THIRD_INPUT];
out_shape[THIRD_INPUT] = hidden_size;
out_shape[THIRD_INPUT] = hidden_size * dir_multiplier;
} else {
if (CheckInputShapeValid(inputs, param) != NNACL_OK) {
return NNACL_ERR;
@@ -99,7 +99,7 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
SetShapeArray(output, out_shape, out_shape_size);
int state_shape[MAX_SHAPE_SIZE];
size_t state_shape_size = 0;
int dir_multiplier = param->bidirectional_ ? 2 : 1;
ShapeSet(state_shape, &state_shape_size, input->shape_, input->shape_size_);
state_shape[FIRST_INPUT] = dir_multiplier;
state_shape[THIRD_INPUT] = hidden_size;


+ 111
- 39
mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc View File

@@ -74,6 +74,7 @@ void LstmCPUKernel::FreeRunBuffer() {
if (output_need_packed_) {
ms_context_->allocator->Free(buffer_[avx_state_output_index]);
}
ms_context_->allocator->Free(buffer_[tmp_hidden_output_index]);
}

int LstmCPUKernel::InitInputWeightBias() {
@@ -91,10 +92,16 @@ int LstmCPUKernel::InitInputWeightBias() {
const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr;
auto weight_i = in_tensors_.at(i_index);
auto weight_i_data = reinterpret_cast<float *>(weight_i->data());
CHECK_NULL_RETURN(weight_i_data);
PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, lstm_param_->hidden_size_,
lstm_param_->input_col_align_, weights_order);

CHECK_NULL_RETURN(weight_i_data);
int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_);
int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_);
int b_size = (lstm_param_->hidden_size_);
bool has_bias = (weight_batch_ * (cw_size + hh_size) < weight_i->ElementsNum()) ? true : false;
int stride = (gpu_orig_state_) ? gate_num * (cw_size + hh_size) : gate_num * (cw_size);
PackLstmWeightWithStride(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_,
lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_,
stride, weights_order);
// input bias
input_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float)));
if (input_bias_ == nullptr) {
@@ -103,18 +110,19 @@ int LstmCPUKernel::InitInputWeightBias() {
}
memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float));

float *bias_data = nullptr;
int offset = gate_num * lstm_param_->hidden_size_ * (lstm_param_->input_size_ + lstm_param_->hidden_size_);
if (weight_i->ElementsNum() > offset) {
bias_data = weight_i_data + offset;
}
int offset = weight_batch_ * (cw_size + hh_size);
float *bias_data = (has_bias) ? weight_i_data + offset : nullptr;
int dir_mul = lstm_param_->bidirectional_ ? 2 : 1;
int b_stride = (gpu_orig_state_) ? gate_num * (dir_mul * b_size) : gate_num * (b_size);
if (in_tensors_.size() > mindir_input_tensors) {
bias_data = reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data());
}

if (bias_data != nullptr) {
PackLstmBias(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_col_align_,
lstm_param_->bidirectional_, weights_order);
} else {
if (bias_data != nullptr) {
PackLstmBiasWithStride(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_,
lstm_param_->input_col_align_, lstm_param_->bidirectional_, b_stride, weights_order);
}
}
return RET_OK;
}
@@ -124,14 +132,23 @@ int LstmCPUKernel::InitStateWeightBias() {
// state -- row: batch; col: hidden_size
// weight -- row: hidden_size; col: hidden_size, need transpose
// result -- row: batch; col: hidden_size
int weight_i_size = gate_num * lstm_param_->hidden_size_ * lstm_param_->input_size_;
int weight_i_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
int h_index = (in_tensors_.size() == mindir_input_tensors) ? combined_weights_index : onnx_weight_h_index;
auto weight_h = in_tensors_.at(h_index);
auto weight_h_data = (reinterpret_cast<float *>(weight_h->data()));

int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_);
int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_);
int b_size = (lstm_param_->hidden_size_);
int stride = (gpu_orig_state_) ? gate_num * (cw_size + hh_size) : gate_num * (hh_size);

if (in_tensors_.size() == mindir_input_tensors) {
weight_h_data += weight_i_size;
if (gpu_orig_state_) {
weight_h_data += gate_num * cw_size;
} else {
weight_h_data += weight_i_size;
}
}

CHECK_NULL_RETURN(weight_h_data);
if (!state_is_vec_) {
weight_h_ptr_ = reinterpret_cast<float *>(
@@ -141,8 +158,9 @@ int LstmCPUKernel::InitStateWeightBias() {
return RET_ERROR;
}
const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr;
PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
lstm_param_->state_col_align_, weights_order);
PackLstmWeightWithStride(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_,
lstm_param_->hidden_size_, lstm_param_->state_col_align_, lstm_param_->bidirectional_,
stride, weights_order);
} else {
#ifdef ENABLE_AVX
weight_h_ptr_ = reinterpret_cast<float *>(
@@ -162,8 +180,8 @@ int LstmCPUKernel::InitStateWeightBias() {
}

// state bias
int weight_h_size = gate_num * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
int bias_size = gate_num * lstm_param_->hidden_size_;
int weight_h_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
int bias_size = weight_batch_ * lstm_param_->hidden_size_;
state_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float)));
if (state_bias_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state_bias_ error.";
@@ -179,9 +197,13 @@ int LstmCPUKernel::InitStateWeightBias() {
lstm_param_->bidirectional_, nullptr);
} else if (weight_h->ElementsNum() - weight_i_size - weight_h_size - C2NUM * bias_size == 0) {
// mindir from device "GPU", second bias is also present, in IFOG order
float *state_bias = weight_h_data + weight_h_size + bias_size;
PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_,
lstm_param_->bidirectional_, weights_order_IFOG);
int dir_mul = lstm_param_->bidirectional_ ? 2 : 1;
int bias_offset =
(gpu_orig_state_) ? gate_num * ((dir_mul - 1) * cw_size + dir_mul * hh_size + b_size) : weight_h_size + bias_size;
float *state_bias = weight_h_data + bias_offset;
int b_stride = (gpu_orig_state_) ? gate_num * (b_size * C2NUM) : gate_num * b_size;
PackLstmBiasWithStride(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_,
lstm_param_->state_col_align_, lstm_param_->bidirectional_, b_stride, weights_order_IFOG);
}
return RET_OK;
}
@@ -202,11 +224,25 @@ int LstmCPUKernel::InitParam() {
} else {
lstm_param_->hidden_size_ = w_shape.at(SECOND_INPUT) / gate_num;
}

lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
: lstm_param_->batch_ * lstm_param_->hidden_size_;
weight_batch_ = lstm_param_->bidirectional_ ? 2 * gate_num : gate_num;
state_is_vec_ = lstm_param_->batch_ == 1;
// determine FB origin
gpu_orig_state_ = false;
if (in_tensors_.size() == mindir_input_tensors) {
gpu_orig_state_ = gpu_orig_cfg_;
auto weight_t = in_tensors_.at(combined_weights_index);
int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_);
int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_);
int b_size = (lstm_param_->hidden_size_);
bool has_bias = (weight_batch_ * (cw_size + hh_size) < weight_t->ElementsNum()) ? true : false;
// if bias exist we can determine the gpu_orig_state_
if (has_bias) {
gpu_orig_state_ =
(weight_batch_ * (cw_size + hh_size + C2NUM * b_size) == weight_t->ElementsNum()) ? true : false;
}
}

#ifdef ENABLE_AVX
row_tile_ = C6NUM;
@@ -279,7 +315,7 @@ int LstmCPUKernel::MallocRunBuffer() {
buffer_[packed_input_index] = reinterpret_cast<float *>(
ms_context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float)));
if (buffer_[packed_input_index] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error.";
MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matrix error.";
return RET_ERROR;
}

@@ -338,6 +374,16 @@ int LstmCPUKernel::MallocRunBuffer() {
}
}
#endif

buffer_[tmp_hidden_output_index] = nullptr;
if (!(in_tensors_.size() > mindir_input_tensors)) {
auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
buffer_[tmp_hidden_output_index] = reinterpret_cast<float *>(ms_context_->allocator->Malloc(buffer_size));
if (buffer_[tmp_hidden_output_index] == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for hidden error.";
return RET_ERROR;
}
}
return RET_OK;
}

@@ -367,7 +413,7 @@ int LstmInputMulWeightRun(void *cdata, int task_id, float, float) {

int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, const float *weight_h,
const float *input_bias, const float *state_bias, float *hidden_state,
float *cell_state, bool is_backward) {
float *cell_state, float *intermediate_states, bool is_backward) {
float *gate = buffer_[input_gate_index];
for (int i = 0; i < gate_num; i++) {
weight_loop_ = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i;
@@ -383,30 +429,46 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons
float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * 2;
float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * 3;
float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_;
float *tmp = buffer_[tmp_hidden_output_index];
int dir_mult = lstm_param_->bidirectional_ ? 2 : 1;
for (int t = 0; t < lstm_param_->seq_len_; t++) {
int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
float *output_ptr = output + real_t * lstm_param_->output_step_;

LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
hidden_state, cell_state, buffer_, lstm_param_);
if (IsTrain() && IsTrainable()) {
// if ONNX
if (in_tensors_.size() > mindir_input_tensors) {
// Sequence, DirMul, Batch, Hidden
float *output_ptr = output + real_t * lstm_param_->output_step_;

LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
hidden_state, cell_state, buffer_, lstm_param_);
} else {
// Sequence, Batch, DirMul, Hidden
LstmStepUnit(tmp, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, hidden_state,
cell_state, buffer_, lstm_param_);
int seq_offset = real_t * lstm_param_->batch_ * dir_mult * lstm_param_->hidden_size_;
for (int b = 0; b < lstm_param_->batch_; b++) {
int batch_offset = b * dir_mult * lstm_param_->hidden_size_;
float *output_ptr = output + seq_offset + batch_offset;
memcpy(output_ptr, tmp + b * lstm_param_->hidden_size_, lstm_param_->hidden_size_ * sizeof(float));
}
}
if (intermediate_states) {
RecordStates(hidden_state, cell_state, input_gate_t, output_gate_t, forget_gate_t, cell_gate_t,
is_backward ? real_t : t);
intermediate_states, real_t);
}
}
return RET_OK;
}

void LstmCPUKernel::RecordStates(float *hidden_state, float *cell_state, float *input_gate, float *output_gate,
float *forget_gate, float *cell_gate, int step) {
float *states = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data());
float *forget_gate, float *cell_gate, float *intermediate_states, int step) {
float *states = intermediate_states;
auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
auto stride = step * state_size;
auto seq_stride = lstm_param_->seq_len_ * state_size;
auto stride = step * lstm_param_->output_step_;
auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_;
memcpy(states + stride, hidden_state, state_size * sizeof(float));
stride += seq_stride;
memcpy(states + stride, cell_state, state_size * sizeof(float));
@@ -425,8 +487,12 @@ int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden
// buffer_[packed_input_index] : store packed input
PackLstmInput(input, buffer_[packed_input_index], lstm_param_->seq_len_ * lstm_param_->batch_,
lstm_param_->input_size_);
auto ret =
LstmUnidirectional(output, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, hidden_state, cell_state, false);
float *intermediate_states = nullptr;
if (IsTrain() && IsTrainable()) {
intermediate_states = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data());
}
auto ret = LstmUnidirectional(output, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, hidden_state,
cell_state, intermediate_states, false);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Lstm unidirectional calculation error.";
return RET_ERROR;
@@ -441,11 +507,17 @@ int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden
const float *backward_input_bias = input_bias_ + gate_num * lstm_param_->input_col_align_;
const float *backward_state_bias = state_bias_ + gate_num * lstm_param_->state_col_align_;
float *backward_output = output + lstm_param_->batch_ * lstm_param_->hidden_size_;
if (in_tensors_.size() == mindir_input_tensors) {
backward_output = output + lstm_param_->hidden_size_;
}
float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_;
float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->hidden_size_;

ret = LstmUnidirectional(backward_output, backward_weight_i, backward_weight_h, backward_input_bias,
backward_state_bias, backward_hidden_state, backward_cell_state, true);
if (intermediate_states) {
intermediate_states += lstm_param_->batch_ * lstm_param_->hidden_size_;
}
ret =
LstmUnidirectional(backward_output, backward_weight_i, backward_weight_h, backward_input_bias,
backward_state_bias, backward_hidden_state, backward_cell_state, intermediate_states, true);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Lstm bidirectional calculation error.";
return RET_ERROR;


+ 8
- 3
mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h View File

@@ -47,10 +47,11 @@ class LstmCPUKernel : public InnerKernel {
int InitStateWeightBias();

int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias,
const float *state_bias, float *hidden_state, float *cell_state, bool is_backward);
const float *state_bias, float *hidden_state, float *cell_state, float *intermediate_states,
bool is_backward);
int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state);
void RecordStates(float *hidden_state, float *cell_state, float *input_gate, float *output_gate, float *forget_gate,
float *cell_gate, int step);
float *cell_gate, float *intermediate_states, int step);
const float *weight_loop_;
const float *bias_loop_;
float *gate_loop_ = nullptr;
@@ -75,7 +76,7 @@ class LstmCPUKernel : public InnerKernel {
int hidden_state_input_index_ = onnx_hidden_state_index;
int cell_state_input_index_ = onnx_cell_state_index;

float *buffer_[7] = {nullptr};
float *buffer_[8] = {nullptr};
const int gate_num = 4;
const int packed_input_index = 0;
const int input_gate_index = 1;
@@ -84,6 +85,7 @@ class LstmCPUKernel : public InnerKernel {
const int cell_state_index = 4;
const int hidden_state_index = 5;
const int avx_state_output_index = 6;
const int tmp_hidden_output_index = 7;
static const int out_intermediate_states_index = 3;
const int weights_order_IFOG[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 5}; // IFGO order to IOFG order

@@ -94,6 +96,9 @@ class LstmCPUKernel : public InnerKernel {
int weight_batch_ = 0;
bool state_is_vec_ = false;
bool output_need_packed_ = false;
// control weight layout
bool gpu_orig_state_ = true;
bool gpu_orig_cfg_ = true;
LstmParameter *lstm_param_ = nullptr;
};
} // namespace mindspore::kernel


+ 112
- 58
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.cc View File

@@ -43,21 +43,11 @@ int LSTMGradDataCPUKernel::ReSize() { return InitParam(); }
int LSTMGradDataCPUKernel::Run() {
auto ret = MallocRunBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "LstmGradDataCPUKernel MallocRunBuffer error.";
MS_LOG(ERROR) << "LSTMGradDataCPUKernel MallocRunBuffer error.";
FreeRunBuffer();
return RET_ERROR;
}

auto output = out_tensors_.at(0);
auto output_ptr = reinterpret_cast<float *>(output->data());
CHECK_NULL_RETURN(output_ptr);

LstmBackpropUnidirectional(output_ptr, false);
FreeRunBuffer();
return RET_OK;
}

int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backward) {
// get input tensors
auto dC_tensor = in_tensors_.at(dC_index);
MS_ASSERT(dC_tensor != nullptr);
@@ -71,8 +61,6 @@ int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(float *output, bool is_bac
MS_ASSERT(intermediate_tensor != nullptr);
auto cell_input_tensor = in_tensors_.at(cell_input_index);
MS_ASSERT(cell_input_tensor != nullptr);

// Get output tensors
auto dX_tensor = out_tensors_.at(dX_out_index);
MS_ASSERT(dX_tensor != nullptr);
auto dH_out_tensor = out_tensors_.at(dH_out_index);
@@ -80,58 +68,119 @@ int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(float *output, bool is_bac
auto dC_out_tensor = out_tensors_.at(dC_out_index);
MS_ASSERT(dC_out_tensor != nullptr);

auto cell_input_data = reinterpret_cast<float *>(cell_input_tensor->data());
auto dh_out = reinterpret_cast<float *>(dH_out_tensor->data());
auto dc_out = reinterpret_cast<float *>(dC_out_tensor->data());
auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data());
auto dC = reinterpret_cast<float *>(dC_tensor->data());
auto dH = reinterpret_cast<float *>(dH_tensor->data());
auto dY = reinterpret_cast<float *>(dy_tensor->data());
auto dX = reinterpret_cast<float *>(dX_tensor->data());
auto weights = reinterpret_cast<float *>(weights_tensor->data());

auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
auto seq_stride = lstm_param_->seq_len_ * state_size;
float *cell_state = intermediate_data + seq_stride * 1;
float *input_gate = intermediate_data + seq_stride * 2;
float *output_gate = intermediate_data + seq_stride * 3;
float *forget_gate = intermediate_data + seq_stride * 4;
float *cell_gate = intermediate_data + seq_stride * 5;
// reorder weights only from IFGO to IOFG
ReorderLstmWeightGrad(weights_tmp_, weights);
memset(dH, 0, dH_tensor->Size());
memset(dC, 0, dC_tensor->Size());
// Get Tensors Data
int time_stamp_len = lstm_param_->batch_ * lstm_param_->hidden_size_;

weights_ = reinterpret_cast<float *>(weights_tensor->data());
ReorderLstmWeightGrad(weights_tmp_, weights_);

dC_ = reinterpret_cast<float *>(dC_tensor->data());
dH_ = reinterpret_cast<float *>(dH_tensor->data());
dX_ = reinterpret_cast<float *>(dX_tensor->data());
memset(dH_, 0, dH_tensor->Size());
memset(dC_, 0, dC_tensor->Size());
memset(dX_, 0, dX_tensor->Size());

int w_size = lstm_param_->hidden_size_ * lstm_param_->input_size_;
int h_size = lstm_param_->hidden_size_ * lstm_param_->hidden_size_;

float *orig_da = dA_tmp_;
if (lstm_param_->bidirectional_) {
// Adjust pointer to backward cell
cell_input_data_ = reinterpret_cast<float *>(cell_input_tensor->data()) + time_stamp_len;
intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data()) + time_stamp_len;
dC_ = reinterpret_cast<float *>(dC_tensor->data()) + time_stamp_len;
dH_ = reinterpret_cast<float *>(dH_tensor->data()) + time_stamp_len;
dY_ = reinterpret_cast<float *>(dy_tensor->data()) + lstm_param_->hidden_size_;
dA_tmp_ = orig_da + lstm_param_->seq_len_ * num_of_gates * time_stamp_len;
int w_offset = num_of_gates * (w_size + h_size);
int v_offset = weight_batch_ * w_size + num_of_gates * h_size;
float *w = weights_tmp_ + w_offset;
float *v = weights_tmp_ + v_offset;
LstmBackpropUnidirectional(true, w, v);
}
// adjust to forward cell
cell_input_data_ = reinterpret_cast<float *>(cell_input_tensor->data());
intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data());
dC_ = reinterpret_cast<float *>(dC_tensor->data());
dH_ = reinterpret_cast<float *>(dH_tensor->data());
dY_ = reinterpret_cast<float *>(dy_tensor->data());

int w_offset = 0;
int v_offset = num_of_gates * w_size;
float *w = weights_tmp_ + w_offset;
float *v = weights_tmp_ + v_offset;
dA_tmp_ = orig_da;
LstmBackpropUnidirectional(false, w, v);

// setup output tensors
dh_out_ = reinterpret_cast<float *>(dH_out_tensor->data());
dc_out_ = reinterpret_cast<float *>(dC_out_tensor->data());
std::copy(&(dH_[0]), &(dH_[dH_tensor->ElementsNum()]), &(dh_out_[0]));
std::copy(&(dC_[0]), &(dC_[dC_tensor->ElementsNum()]), &(dc_out_[0]));

auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_;
float *cell_state = intermediate_data_ + seq_stride * 1;
std::copy(&(dA_tmp_[0]), &(dA_tmp_[num_of_gates * seq_stride]), &(cell_state[0]));
FreeRunBuffer();
return RET_OK;
}

int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(bool is_backward, float *w, float *v) {
auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_;
int state_len = lstm_param_->batch_ * lstm_param_->hidden_size_;
float *cell_state = intermediate_data_ + seq_stride * 1;
float *input_gate = intermediate_data_ + seq_stride * 2;
float *output_gate = intermediate_data_ + seq_stride * 3;
float *forget_gate = intermediate_data_ + seq_stride * 4;
float *cell_gate = intermediate_data_ + seq_stride * 5;

int dir_mult = lstm_param_->bidirectional_ ? 2 : 1;
int prev_time_stamp_offset = (is_backward) ? 1 : -1;
int first_time_stamp = (is_backward) ? lstm_param_->seq_len_ - 1 : 0;
for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) {
int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
auto stride = real_t * state_size;

auto stride = real_t * lstm_param_->output_step_;
float *curr_cell_state = cell_state + stride;
float *prev_cell_state = (real_t > 0) ? cell_state + (real_t - 1) * state_size : cell_input_data;
float *prev_cell_state = (real_t == first_time_stamp)
? cell_input_data_
: cell_state + (real_t + prev_time_stamp_offset) * lstm_param_->output_step_;
float *curr_input_gate = input_gate + stride;
float *curr_forget_gate = forget_gate + stride;
float *curr_cell_gate = cell_gate + stride;
float *curr_output_gate = output_gate + stride;
float *curr_dx = dX + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *curr_dy = dY + real_t * state_size;

float *curr_dx = dX_ + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
int seq_offset = real_t * lstm_param_->output_step_;
for (int b = 0; b < lstm_param_->batch_; b++) {
int batch_offset = b * dir_mult * lstm_param_->hidden_size_;
float *dy = dY_ + seq_offset + batch_offset;
memcpy(curr_dy_ + b * lstm_param_->hidden_size_, dy, lstm_param_->hidden_size_ * sizeof(float));
}
float *dA = nullptr;
LstmGradDoInputStep(curr_output_gate, curr_cell_state, prev_cell_state, curr_cell_gate, curr_input_gate,
curr_forget_gate, curr_dy, dC, dH, &dA, curr_dx, weights_tmp_, workspace_, lstm_param_);
float *dA_t = dA_tmp_ + t * num_of_gates * lstm_param_->output_step_;
std::copy(&(dA[0]), &(dA[num_of_gates * lstm_param_->output_step_]), &dA_t[0]); // for w grad step
curr_forget_gate, curr_dy_, dC_, dH_, &dA, curr_dx, w, v, workspace_, lstm_param_);
float *dA_t = dA_tmp_ + real_t * num_of_gates * state_len;
std::copy(&(dA[0]), &(dA[num_of_gates * state_len]), &dA_t[0]); // for w grad step
}
std::copy(&(dH[0]), &(dH[state_size]), &(dh_out[0]));
std::copy(&(dC[0]), &(dC[state_size]), &(dc_out[0]));
std::copy(&(dA_tmp_[0]), &(dA_tmp_[num_of_gates * lstm_param_->output_step_ * lstm_param_->seq_len_]),
&(cell_state[0]));
return RET_OK;
}

void LSTMGradDataCPUKernel::ReorderLstmWeightGrad(float *dst, float *src) {
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIFGO());
src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIFGO());
int uni_batch = lstm_param_->bidirectional_ ? weight_batch_ / 2 : weight_batch_;
ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIFGO());
src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_;
dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_;
ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIFGO());
src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
if (lstm_param_->bidirectional_) {
ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIFGO());
src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_;
dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_;
ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIFGO());
src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
}
}

int LSTMGradDataCPUKernel::DoGrad(int thread_id) { return RET_OK; }
@@ -143,11 +192,6 @@ int LSTMGradDataCPUKernel::InitParam() {
lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT);
lstm_param_->batch_ = in_shape.at(SECOND_INPUT);

auto dy = in_tensors_.at(dy_index);
MS_ASSERT(dy != nullptr);
std::vector<int> dy_shape = dy->shape();
lstm_param_->hidden_size_ = dy_shape.at(THIRD_INPUT);

int dir_multiplier = lstm_param_->bidirectional_ ? 2 : 1;
lstm_param_->output_step_ = dir_multiplier * lstm_param_->batch_ * lstm_param_->hidden_size_;
weight_batch_ = dir_multiplier * num_of_gates;
@@ -184,7 +228,7 @@ int LSTMGradDataCPUKernel::InitParam() {

int LSTMGradDataCPUKernel::MallocRunBuffer() {
int workspace_size = GetRunWorkspaceSize(lstm_param_);
if ((workspace_size == 0) || (workspace_size > LSTMGRADDATA_MAX_WORKSPACE_SIZE)) {
if (workspace_size == 0) {
MS_LOG(ERROR) << "LstmGradDataCPUKernel malloc run workspace 0 error.";
return RET_ERROR;
}
@@ -194,7 +238,7 @@ int LSTMGradDataCPUKernel::MallocRunBuffer() {
return RET_ERROR;
}
auto dA_size = num_of_gates * lstm_param_->output_step_ * lstm_param_->seq_len_;
if ((dA_size == 0) || (dA_size > LSTMGRADDATA_MAX_WORKSPACE_SIZE)) {
if (dA_size == 0) {
MS_LOG(ERROR) << "LstmGradDataCPUKernel malloc run dA_tmp size error.";
return RET_ERROR;
}
@@ -210,6 +254,12 @@ int LSTMGradDataCPUKernel::MallocRunBuffer() {
MS_LOG(ERROR) << "LstmGradWeightCPUKernel malloc run weights_tmp_ alloc error.";
return RET_ERROR;
}
int curr_dy_size = lstm_param_->hidden_size_ * lstm_param_->batch_;
curr_dy_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(curr_dy_size * sizeof(float)));
if (curr_dy_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc run curr_dy_ alloc error.";
return RET_ERROR;
}
return RET_OK;
}

@@ -226,6 +276,10 @@ void LSTMGradDataCPUKernel::FreeRunBuffer() {
ms_context_->allocator->Free(weights_tmp_);
weights_tmp_ = nullptr;
}
if (curr_dy_ != nullptr) {
ms_context_->allocator->Free(curr_dy_);
curr_dy_ = nullptr;
}
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTMGradData, LiteKernelCreator<LSTMGradDataCPUKernel>)


+ 11
- 3
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.h View File

@@ -23,8 +23,6 @@

namespace mindspore {
namespace kernel {
constexpr int LSTMGRADDATA_MAX_WORKSPACE_SIZE = 100000;
constexpr int LSTMGRADDATA_MAX_WEIGHTS_SIZE = 100000;
class LSTMGradDataCPUKernel : public InnerKernel {
public:
explicit LSTMGradDataCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -39,7 +37,7 @@ class LSTMGradDataCPUKernel : public InnerKernel {
int DoGrad(int thread_id);

private:
int LstmBackpropUnidirectional(float *output, bool is_backward);
int LstmBackpropUnidirectional(bool is_backward, float *w, float *v);

void ReorderLstmWeightGrad(float *dst, float *src);
int InitParam();
@@ -71,6 +69,16 @@ class LSTMGradDataCPUKernel : public InnerKernel {
int input_thread_count_ = 0;
int input_thread_stride_ = 0;

float *curr_dy_ = nullptr;
float *weights_ = nullptr;
float *dC_ = nullptr;
float *dH_ = nullptr;
float *dX_ = nullptr;
float *dY_ = nullptr;
float *cell_input_data_ = nullptr;
float *intermediate_data_ = nullptr;
float *dh_out_ = nullptr;
float *dc_out_ = nullptr;
LstmGradParameter *lstm_param_ = nullptr;
};
} // namespace kernel


+ 106
- 52
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.cc View File

@@ -47,12 +47,7 @@ int LSTMGradCPUKernel::Run() {
FreeRunBuffer();
return RET_ERROR;
}
LstmBackpropUnidirectional(false);
FreeRunBuffer();
return RET_OK;
}

int LSTMGradCPUKernel::LstmBackpropUnidirectional(bool is_backward) {
// get input tensors
auto dC_tensor = in_tensors_.at(dC_index);
MS_ASSERT(dC_tensor != nullptr);
@@ -81,58 +76,113 @@ int LSTMGradCPUKernel::LstmBackpropUnidirectional(bool is_backward) {
auto dC_out_tensor = out_tensors_.at(dC_out_index);
MS_ASSERT(dC_out_tensor != nullptr);

auto cell_input_data = reinterpret_cast<float *>(cell_input_tensor->data());
auto hidden_input_data = reinterpret_cast<float *>(hidden_input_tensor->data());
auto dh_out = reinterpret_cast<float *>(dH_out_tensor->data());
auto dc_out = reinterpret_cast<float *>(dC_out_tensor->data());
auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data());
auto dC = reinterpret_cast<float *>(dC_tensor->data());
auto dH = reinterpret_cast<float *>(dH_tensor->data());
auto dY = reinterpret_cast<float *>(dy_tensor->data());
auto dW = reinterpret_cast<float *>(dW_tensor->data());
auto dX = reinterpret_cast<float *>(dX_tensor->data());
auto weights = reinterpret_cast<float *>(weights_tensor->data());
auto input = reinterpret_cast<float *>(input_tensor->data());

auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
auto seq_stride = lstm_param_->seq_len_ * state_size;
float *hidden_state = intermediate_data;
float *cell_state = intermediate_data + seq_stride * 1;
float *input_gate = intermediate_data + seq_stride * 2;
float *output_gate = intermediate_data + seq_stride * 3;
float *forget_gate = intermediate_data + seq_stride * 4;
float *cell_gate = intermediate_data + seq_stride * 5;
ReorderLstmWeightGrad(weights_tmp_, weights, getLstmOrderIFGO(), false);

memset(dH, 0, dH_tensor->Size());
memset(dC, 0, dC_tensor->Size());

memset(dW_tmp_, 0, dW_tensor->Size()); // dW_tmp is summed in the loop
float *workspace_gemm = workspace_ + GetRunWorkspaceGemmOffset(lstm_param_);
// Get Tensors Data
int time_stamp_len = lstm_param_->batch_ * lstm_param_->hidden_size_;

weights_ = reinterpret_cast<float *>(weights_tensor->data());
ReorderLstmWeightGrad(weights_tmp_, weights_, getLstmOrderIFGO(), false);

dC_ = reinterpret_cast<float *>(dC_tensor->data());
dH_ = reinterpret_cast<float *>(dH_tensor->data());
dX_ = reinterpret_cast<float *>(dX_tensor->data());
input_ = reinterpret_cast<float *>(input_tensor->data());
memset(dH_, 0, dH_tensor->Size());
memset(dC_, 0, dC_tensor->Size());
memset(dX_, 0, dX_tensor->Size());
memset(dW_tmp_, 0, dW_tensor->Size());

if (lstm_param_->bidirectional_) {
// Adjust pointer to backward cell
cell_input_data_ = reinterpret_cast<float *>(cell_input_tensor->data()) + time_stamp_len;
hidden_input_data_ = reinterpret_cast<float *>(hidden_input_tensor->data()) + time_stamp_len;
intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data()) + time_stamp_len;
dC_ = reinterpret_cast<float *>(dC_tensor->data()) + time_stamp_len;
dH_ = reinterpret_cast<float *>(dH_tensor->data()) + time_stamp_len;
dY_ = reinterpret_cast<float *>(dy_tensor->data()) + lstm_param_->hidden_size_;
int w_offset = num_of_gates * lstm_param_->hidden_size_ * lstm_param_->input_size_;
int v_offset = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_ +
num_of_gates * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
int b_offset = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_ +
weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_ +
num_of_gates * lstm_param_->hidden_size_;
float *w = weights_tmp_ + w_offset;
float *v = weights_tmp_ + v_offset;
float *dw = dW_tmp_ + w_offset;
float *dv = dW_tmp_ + v_offset;
float *db = (lstm_param_->has_bias_) ? dW_tmp_ + b_offset : nullptr;
LstmBackpropUnidirectional(true, w, v, dw, dv, db);
}
// adjust to forward cell
cell_input_data_ = reinterpret_cast<float *>(cell_input_tensor->data());
hidden_input_data_ = reinterpret_cast<float *>(hidden_input_tensor->data());
intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data());
dC_ = reinterpret_cast<float *>(dC_tensor->data());
dH_ = reinterpret_cast<float *>(dH_tensor->data());
dY_ = reinterpret_cast<float *>(dy_tensor->data());
int w_offset = 0;
int v_offset = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
int b_offset = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_ +
weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
float *w = weights_tmp_ + w_offset;
float *v = weights_tmp_ + v_offset;
float *dw = dW_tmp_ + w_offset;
float *dv = dW_tmp_ + v_offset;
float *db = (lstm_param_->has_bias_) ? dW_tmp_ + b_offset : nullptr;
LstmBackpropUnidirectional(false, w, v, dw, dv, db);
// setup output tensors
dW_ = reinterpret_cast<float *>(dW_tensor->data());
dh_out_ = reinterpret_cast<float *>(dH_out_tensor->data());
dc_out_ = reinterpret_cast<float *>(dC_out_tensor->data());
std::copy(&(dH_[0]), &(dH_[dH_tensor->ElementsNum()]), &(dh_out_[0]));
std::copy(&(dC_[0]), &(dC_[dC_tensor->ElementsNum()]), &(dc_out_[0]));
ReorderLstmWeightGrad(dW_, dW_tmp_, getLstmOrderIOFG(), lstm_param_->has_bias_);
FreeRunBuffer();
return RET_OK;
}

int LSTMGradCPUKernel::LstmBackpropUnidirectional(bool is_backward, float *w, float *v, float *dw, float *dv,
float *db) {
auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_;
float *hidden_state = intermediate_data_;
float *cell_state = intermediate_data_ + seq_stride * 1;
float *input_gate = intermediate_data_ + seq_stride * 2;
float *output_gate = intermediate_data_ + seq_stride * 3;
float *forget_gate = intermediate_data_ + seq_stride * 4;
float *cell_gate = intermediate_data_ + seq_stride * 5;

float *workspace_gemm = workspace_ + GetRunWorkspaceGemmOffset(lstm_param_);
int dir_mult = lstm_param_->bidirectional_ ? 2 : 1;
int prev_time_stamp_offset = (is_backward) ? 1 : -1;
int first_time_stamp = (is_backward) ? lstm_param_->seq_len_ - 1 : 0;
for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) {
int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
auto stride = real_t * state_size;
auto stride = real_t * lstm_param_->output_step_;

float *prev_hidden_state = (real_t > 0) ? hidden_state + (real_t - 1) * state_size : hidden_input_data;
float *prev_hidden_state = (real_t == first_time_stamp)
? hidden_input_data_
: hidden_state + (real_t + prev_time_stamp_offset) * lstm_param_->output_step_;
float *curr_cell_state = cell_state + stride;
float *prev_cell_state = (real_t > 0) ? cell_state + (real_t - 1) * state_size : cell_input_data;
float *prev_cell_state = (real_t == first_time_stamp)
? cell_input_data_
: cell_state + (real_t + prev_time_stamp_offset) * lstm_param_->output_step_;
float *curr_input_gate = input_gate + stride;
float *curr_forget_gate = forget_gate + stride;
float *curr_cell_gate = cell_gate + stride;
float *curr_output_gate = output_gate + stride;
float *curr_input = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *curr_dx = dX + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *curr_dy = dY + real_t * state_size;
float *curr_input = input_ + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *curr_dx = dX_ + real_t * lstm_param_->batch_ * lstm_param_->input_size_;

int seq_offset = real_t * lstm_param_->output_step_;
for (int b = 0; b < lstm_param_->batch_; b++) {
int batch_offset = b * dir_mult * lstm_param_->hidden_size_;
float *dy = dY_ + seq_offset + batch_offset;
memcpy(curr_dy_ + b * lstm_param_->hidden_size_, dy, lstm_param_->hidden_size_ * sizeof(float));
}
float *dA = nullptr;
LstmGradDoInputStep(curr_output_gate, curr_cell_state, prev_cell_state, curr_cell_gate, curr_input_gate,
curr_forget_gate, curr_dy, dC, dH, &dA, curr_dx, weights_tmp_, workspace_, lstm_param_);
LstmGradDoWeightStep(curr_input, prev_hidden_state, dA, dW_tmp_, workspace_gemm, lstm_param_);
curr_forget_gate, curr_dy_, dC_, dH_, &dA, curr_dx, w, v, workspace_, lstm_param_);
LstmGradDoWeightStep(curr_input, prev_hidden_state, dA, dw, dv, db, workspace_gemm, lstm_param_);
}
std::copy(&(dH[0]), &(dH[state_size]), &(dh_out[0]));
std::copy(&(dC[0]), &(dC[state_size]), &(dc_out[0]));
ReorderLstmWeightGrad(dW, dW_tmp_, getLstmOrderIOFG(), lstm_param_->has_bias_);
return RET_OK;
}

@@ -158,11 +208,6 @@ int LSTMGradCPUKernel::InitParam() {
lstm_param_->batch_ = in_shape.at(SECOND_INPUT);
lstm_param_->input_size_ = in_shape.at(THIRD_INPUT);

auto dy = in_tensors_.at(dy_index);
MS_ASSERT(dy != nullptr);
std::vector<int> dy_shape = dy->shape();
lstm_param_->hidden_size_ = dy_shape.at(THIRD_INPUT);

int dir_multiplier = lstm_param_->bidirectional_ ? 2 : 1;
lstm_param_->output_step_ = dir_multiplier * lstm_param_->batch_ * lstm_param_->hidden_size_;
weight_batch_ = dir_multiplier * num_of_gates;
@@ -198,7 +243,7 @@ int LSTMGradCPUKernel::InitParam() {
}
int LSTMGradCPUKernel::MallocRunBuffer() {
int workspace_size = GetRunWorkspaceSize(lstm_param_);
if ((workspace_size == 0) || (workspace_size > LSTMGRAD_MAX_WORKSPACE_SIZE)) {
if (workspace_size == 0) {
MS_LOG(ERROR) << "LstmCPUKernel malloc run workspace 0 error.";
return RET_ERROR;
}
@@ -210,7 +255,7 @@ int LSTMGradCPUKernel::MallocRunBuffer() {
auto dW_tensor = out_tensors_.at(dW_out_index);
MS_ASSERT(dW_tensor != nullptr);
auto dW_size = dW_tensor->Size();
if ((dW_size == 0) || (dW_size > LSTMGRAD_MAX_WEIGHTS_SIZE)) {
if (dW_size == 0) {
MS_LOG(ERROR) << "LstmCPUKernel malloc run dW_tmp size error.";
return RET_ERROR;
}
@@ -226,7 +271,12 @@ int LSTMGradCPUKernel::MallocRunBuffer() {
MS_LOG(ERROR) << "LstmCPUKernel malloc run weights_tmp_ alloc error.";
return RET_ERROR;
}

int curr_dy_size = lstm_param_->hidden_size_ * lstm_param_->batch_;
curr_dy_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(curr_dy_size * sizeof(float)));
if (curr_dy_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc run curr_dy_ alloc error.";
return RET_ERROR;
}
return RET_OK;
}

@@ -243,6 +293,10 @@ void LSTMGradCPUKernel::FreeRunBuffer() {
ms_context_->allocator->Free(weights_tmp_);
weights_tmp_ = nullptr;
}
if (curr_dy_ != nullptr) {
ms_context_->allocator->Free(curr_dy_);
curr_dy_ = nullptr;
}
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTMGrad, LiteKernelCreator<LSTMGradCPUKernel>)


+ 15
- 3
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.h View File

@@ -23,8 +23,6 @@

namespace mindspore {
namespace kernel {
constexpr int LSTMGRAD_MAX_WORKSPACE_SIZE = 100000;
constexpr int LSTMGRAD_MAX_WEIGHTS_SIZE = 100000;
class LSTMGradCPUKernel : public InnerKernel {
public:
explicit LSTMGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -39,7 +37,7 @@ class LSTMGradCPUKernel : public InnerKernel {
int DoGrad(int thread_id);

private:
int LstmBackpropUnidirectional(bool is_backward);
int LstmBackpropUnidirectional(bool is_backward, float *w, float *v, float *dw, float *dv, float *db);

int InitParam();
int MallocRunBuffer();
@@ -73,6 +71,20 @@ class LSTMGradCPUKernel : public InnerKernel {
int input_thread_count_ = 0;
int input_thread_stride_ = 0;

float *cell_input_data_ = nullptr;
float *hidden_input_data_ = nullptr;
float *dh_out_ = nullptr;
float *dc_out_ = nullptr;
float *intermediate_data_ = nullptr;
float *dC_ = nullptr;
float *dH_ = nullptr;
float *dY_ = nullptr;
float *dW_ = nullptr;
float *dX_ = nullptr;
float *weights_ = nullptr;
float *input_ = nullptr;
float *curr_dy_ = nullptr;

LstmGradParameter *lstm_param_ = nullptr;
};
} // namespace kernel


+ 98
- 46
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_weight_fp32.cc View File

@@ -47,59 +47,115 @@ int LSTMGradWeightCPUKernel::Run() {
return RET_ERROR;
}

auto output = out_tensors_.at(0);
auto output_ptr = reinterpret_cast<float *>(output->data());
CHECK_NULL_RETURN(output_ptr);

LstmBackpropUnidirectional(output_ptr, false);
FreeRunBuffer();
return RET_OK;
}

int LSTMGradWeightCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backward) {
auto dW_tensor = out_tensors_.at(dW_out_index);
MS_ASSERT(dW_tensor != nullptr);
auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
MS_ASSERT(intermediate_tensor != nullptr);
auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_;
auto input_tensor = in_tensors_.at(input_index);
MS_ASSERT(input_tensor != nullptr);
auto hidden_input_tensor = in_tensors_.at(hidden_input_index);
MS_ASSERT(hidden_input_tensor != nullptr);
auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
MS_ASSERT(intermediate_tensor != nullptr);

auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data());
auto input = reinterpret_cast<float *>(input_tensor->data());
auto dW = reinterpret_cast<float *>(dW_tensor->data());
auto hidden_input_data = reinterpret_cast<float *>(hidden_input_tensor->data());
// Get output tensors
auto dW_tensor = out_tensors_.at(dW_out_index);
MS_ASSERT(dW_tensor != nullptr);

// Get Tensors Data
int time_stamp_len = lstm_param_->batch_ * lstm_param_->hidden_size_;
input_ = reinterpret_cast<float *>(input_tensor->data());
memset(dW_tmp_, 0, dW_tensor->Size());

int w_size = lstm_param_->hidden_size_ * lstm_param_->input_size_;
int h_size = lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
int b_size = lstm_param_->hidden_size_;

if (lstm_param_->bidirectional_) {
// Adjust pointer to backward cell
hidden_input_data_ = reinterpret_cast<float *>(hidden_input_tensor->data()) + time_stamp_len;
intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data());
dA_ = intermediate_data_ + seq_stride * 1 + lstm_param_->seq_len_ * num_of_gates * time_stamp_len;
intermediate_data_ += time_stamp_len;

int w_offset = num_of_gates * (w_size + h_size);
int v_offset = weight_batch_ * w_size + num_of_gates * h_size;
int b_offset = weight_batch_ * (w_size + h_size) + num_of_gates * b_size;
float *dw = dW_tmp_ + w_offset;
float *dv = dW_tmp_ + v_offset;
float *db = nullptr;
if (lstm_param_->has_bias_) {
db = dW_tmp_ + b_offset;
}
LstmBackpropUnidirectional(true, dw, dv, db);
}
// adjust to forward cell
hidden_input_data_ = reinterpret_cast<float *>(hidden_input_tensor->data());
intermediate_data_ = reinterpret_cast<float *>(intermediate_tensor->data());
dA_ = intermediate_data_ + seq_stride * 1;
int w_offset = 0;
int v_offset = num_of_gates * w_size;
int b_offset = weight_batch_ * (w_size + h_size);
float *dw = dW_tmp_ + w_offset;
float *dv = dW_tmp_ + v_offset;
float *db = nullptr;
if (lstm_param_->has_bias_) {
db = dW_tmp_ + b_offset;
}
LstmBackpropUnidirectional(false, dw, dv, db);

auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
auto seq_stride = lstm_param_->seq_len_ * state_size;
float *hidden_state = intermediate_data;
float *dA = intermediate_data + seq_stride * 1; // intremidate tensor used to transfer dA data from GradData kernel
// setup output tensors
dW_ = reinterpret_cast<float *>(dW_tensor->data());
ReorderLstmWeightGrad(dW_, dW_tmp_, lstm_param_);
FreeRunBuffer();
return RET_OK;
}

int LSTMGradWeightCPUKernel::LstmBackpropUnidirectional(bool is_backward, float *dw, float *dv, float *db) {
float *hidden_state = intermediate_data_;
int state_len = lstm_param_->batch_ * lstm_param_->hidden_size_;
int prev_time_stamp_offset = (is_backward) ? 1 : -1;
int first_time_stamp = (is_backward) ? lstm_param_->seq_len_ - 1 : 0;

memset(dW_tmp_, 0, dW_tensor->Size()); // dW_tmp is summed in the loop
for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) {
int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
float *curr_input = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *prev_hidden_state = (real_t > 0) ? hidden_state + (real_t - 1) * state_size : hidden_input_data;
float *curr_da = dA + real_t * num_of_gates * lstm_param_->output_step_;
LstmGradDoWeightStep(curr_input, prev_hidden_state, curr_da, dW_tmp_, workspace_, lstm_param_);
float *prev_hidden_state = (real_t == first_time_stamp)
? hidden_input_data_
: hidden_state + (real_t + prev_time_stamp_offset) * lstm_param_->output_step_;
float *curr_input = input_ + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *curr_da = dA_ + real_t * num_of_gates * state_len;
LstmGradDoWeightStep(curr_input, prev_hidden_state, curr_da, dw, dv, db, workspace_, lstm_param_);
}
ReorderLstmWeightGrad(dW, dW_tmp_, lstm_param_->has_bias_);
return RET_OK;
}

void LSTMGradWeightCPUKernel::ReorderLstmWeightGrad(float *dst, float *src, bool has_bias) {
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIOFG());
src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIOFG());
src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
if (has_bias) {
ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, getLstmOrderIOFG());
// update secend bias term (only if separate GradData and GradWeight)
dst += weight_batch_ * lstm_param_->hidden_size_;
ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, getLstmOrderIOFG());
void LSTMGradWeightCPUKernel::ReorderLstmWeightGrad(float *dst, float *src, LstmGradParameter *param) {
int uni_batch = param->bidirectional_ ? weight_batch_ / 2 : weight_batch_;
// 4xWixWh,4xWirxWhr,4xBiBh,4xBirBhr
ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIOFG());
src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_;
dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_;
ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIOFG());
src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
if (param->bidirectional_) {
ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIOFG());
src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_;
dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->input_size_;
ReorderLstmWeights(dst, src, uni_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIOFG());
src += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
dst += uni_batch * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
}
if (param->has_bias_) {
ReorderLstmWeights(dst, src, uni_batch, 1, lstm_param_->hidden_size_, getLstmOrderIOFG());
dst += uni_batch * lstm_param_->hidden_size_;
ReorderLstmWeights(dst, src, uni_batch, 1, lstm_param_->hidden_size_, getLstmOrderIOFG());
dst += uni_batch * lstm_param_->hidden_size_;
src += uni_batch * lstm_param_->hidden_size_;
if (param->bidirectional_) {
ReorderLstmWeights(dst, src, uni_batch, 1, lstm_param_->hidden_size_, getLstmOrderIOFG());
dst += uni_batch * lstm_param_->hidden_size_;
ReorderLstmWeights(dst, src, uni_batch, 1, lstm_param_->hidden_size_, getLstmOrderIOFG());
dst += uni_batch * lstm_param_->hidden_size_;
src += uni_batch * lstm_param_->hidden_size_;
}
}
}

@@ -112,10 +168,6 @@ int LSTMGradWeightCPUKernel::InitParam() {
lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT);
lstm_param_->batch_ = in_shape.at(SECOND_INPUT);
lstm_param_->input_size_ = in_shape.at(THIRD_INPUT);
auto y = in_tensors_.at(y_index);
MS_ASSERT(y != nullptr);
std::vector<int> y_shape = y->shape();
lstm_param_->hidden_size_ = y_shape.at(THIRD_INPUT);

int dir_multiplier = lstm_param_->bidirectional_ ? 2 : 1;
lstm_param_->output_step_ = dir_multiplier * lstm_param_->batch_ * lstm_param_->hidden_size_;
@@ -153,7 +205,7 @@ int LSTMGradWeightCPUKernel::InitParam() {

int LSTMGradWeightCPUKernel::MallocRunBuffer() {
int workspace_size = GetRunWorkspaceSize(lstm_param_);
if ((workspace_size == 0) || (workspace_size > LSTMGRADWEIGHT_MAX_WORKSPACE_SIZE)) {
if (workspace_size == 0) {
MS_LOG(ERROR) << "LstmGradWeightCPUKernel malloc run workspace 0 error.";
return RET_ERROR;
}
@@ -166,7 +218,7 @@ int LSTMGradWeightCPUKernel::MallocRunBuffer() {
auto dW_tensor = out_tensors_.at(dW_out_index);
MS_ASSERT(dW_tensor != nullptr);
auto dW_size = dW_tensor->Size();
if ((dW_size == 0) || (dW_size > LSTMGRADWEIGHT_MAX_WEIGHTS_SIZE)) {
if (dW_size == 0) {
MS_LOG(ERROR) << "LstmGradWeightCPUKernel malloc run dW_tmp size error.";
return RET_ERROR;
}


+ 7
- 5
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_weight_fp32.h View File

@@ -23,8 +23,6 @@

namespace mindspore {
namespace kernel {
constexpr int LSTMGRADWEIGHT_MAX_WORKSPACE_SIZE = 100000;
constexpr int LSTMGRADWEIGHT_MAX_WEIGHTS_SIZE = 100000;
class LSTMGradWeightCPUKernel : public InnerKernel {
public:
explicit LSTMGradWeightCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -39,12 +37,12 @@ class LSTMGradWeightCPUKernel : public InnerKernel {
int DoGrad(int thread_id);

private:
int LstmBackpropUnidirectional(float *output, bool is_backward);
int LstmBackpropUnidirectional(bool is_backward, float *dw, float *dv, float *db);

int InitParam();
int MallocRunBuffer();
void FreeRunBuffer();
void ReorderLstmWeightGrad(float *dst, float *src, bool has_bias);
void ReorderLstmWeightGrad(float *dst, float *src, LstmGradParameter *param);

static const int input_index = 0;
static const int hidden_input_index = 1;
@@ -65,7 +63,11 @@ class LSTMGradWeightCPUKernel : public InnerKernel {
bool state_is_vec_ = false;
int input_thread_count_ = 0;
int input_thread_stride_ = 0;

float *input_ = nullptr;
float *hidden_input_data_ = nullptr;
float *intermediate_data_ = nullptr;
float *dW_ = nullptr;
float *dA_ = nullptr;
LstmGradParameter *lstm_param_ = nullptr;
};
} // namespace kernel


+ 3
- 1
mindspore/lite/src/train/train_populate_parameter.cc View File

@@ -526,6 +526,7 @@ OpParameter *PopulateLstmGradParameter(const void *prim) {
param->zoneout_hidden_ = value->zoneout_hidden();
param->input_size_ = value->input_size();
param->has_bias_ = value->has_bias();
param->hidden_size_ = value->hidden_size();

return reinterpret_cast<OpParameter *>(param);
}
@@ -552,6 +553,7 @@ OpParameter *PopulateLstmGradDataParameter(const void *prim) {
param->zoneout_hidden_ = value->zoneout_hidden();
param->input_size_ = value->input_size();
param->has_bias_ = value->has_bias();
param->hidden_size_ = value->hidden_size();
return reinterpret_cast<OpParameter *>(param);
}

@@ -573,7 +575,7 @@ OpParameter *PopulateLstmGradWeightParameter(const void *prim) {

param->op_parameter_.type_ = primitive->value_type();
param->input_size_ = value->input_size();
param->hidden_size_ = value->hidden_size(); // output_size
param->hidden_size_ = value->hidden_size();
param->bidirectional_ = value->bidirectional();
param->zoneout_cell_ = value->zoneout_cell();
param->zoneout_hidden_ = value->zoneout_hidden();


Loading…
Cancel
Save