Browse Source

LSTM Grad Initial Support

feature/build-system-rewrite
Haim Moushkatel 4 years ago
parent
commit
5137ae1c01
18 changed files with 463 additions and 273 deletions
  1. +190
    -63
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/lstm_grad_fp32.c
  2. +34
    -12
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/lstm_grad_fp32.h
  3. +13
    -6
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_grad_data_infer.c
  4. +1
    -1
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_grad_infer.c
  5. +3
    -3
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_grad_weight_infer.c
  6. +0
    -1
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/strided_slice_grad_infer.c
  7. +10
    -17
      mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
  8. +1
    -3
      mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
  9. +62
    -34
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.cc
  10. +6
    -18
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.h
  11. +75
    -50
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.cc
  12. +7
    -7
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.h
  13. +21
    -17
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_weight_fp32.cc
  14. +7
    -20
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_weight_fp32.h
  15. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc
  16. +15
    -11
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc
  17. +16
    -9
      mindspore/lite/src/train/train_populate_parameter.cc
  18. +1
    -0
      mindspore/lite/tools/benchmark_train/net_train.h

+ 190
- 63
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/lstm_grad_fp32.c View File

@@ -26,8 +26,8 @@
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/nnacl_utils.h"

static const int no_of_temp_matrices_sized_output_step = 10;
static const int num_of_gates = 4;
static const int no_of_temp_matrices_sized_output_step = 5;

static inline float *AllocteFromScrachPad(float **scrach_pad, int size) {
float *buffer = *scrach_pad;
@@ -35,6 +35,13 @@ static inline float *AllocteFromScrachPad(float **scrach_pad, int size) {
return buffer;
}

static const int weights_order_IOFG[2 * 4] = {0, 3, 1, 2, 4, 7, 5, 6}; // IOFG order to IFGO order
static const int weights_order_IFGO[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 5}; // IFGO order to IOFG order

// Exposes the IOFG->IFGO gate reorder table (weights_order_IOFG above).
const int *getLstmOrderIOFG(void) { return weights_order_IOFG; }

// Exposes the IFGO->IOFG gate reorder table (weights_order_IFGO above).
const int *getLstmOrderIFGO(void) { return weights_order_IFGO; }

void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align,
const int *order) {
for (int i = 0; i < batch; i++) {
@@ -86,67 +93,199 @@ int GetGemmMatMullWorkspace(int batch, int input_size, int hidden_size) {
return workspace_size;
}

int GetRunWorkspaceSize(const LstmParameter *lstm_param) {
// Total scratch-pad size (in floats) needed by one LSTM grad run step:
// the fixed set of output_step_-sized temporary matrices plus the GEMM
// matmul workspace.
int GetRunWorkspaceSize(const LstmGradParameter *lstm_param) {
  const int temp_matrices_area = no_of_temp_matrices_sized_output_step * lstm_param->output_step_;
  const int gemm_area =
    GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_);
  return temp_matrices_area + gemm_area;
}

void LstmGradStepUnitInit(const float *output_gate, float *cell_state, float *dY, float *dC, float *dH,
float *workspace, const LstmParameter *lstm_param) {
int state_size = lstm_param->batch_ * lstm_param->hidden_size_;
memcpy(dH, dY, state_size * sizeof(float));
float *workspace_i = workspace;
float *tanh_c = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);
float *temp = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);
ElementMul(dH, output_gate, dC, lstm_param->output_step_);
Tanh(cell_state, lstm_param->output_step_, tanh_c);
ElementMul(tanh_c, tanh_c, tanh_c, lstm_param->output_step_);
ElementMul(dC, tanh_c, temp, lstm_param->output_step_);
ElementSub(dC, temp, dC, lstm_param->output_step_);
// Offset (in floats) within the run workspace at which the GEMM scratch
// area begins — immediately after the temporary matrices region.
size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param) {
  const int temp_matrices_floats = no_of_temp_matrices_sized_output_step * lstm_param->output_step_;
  return temp_matrices_floats;
}

void LstmGradStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
float *hidden_state, float *cell_state, float *dC, float *dH, float *dY, float *dX,
float *cell_state_minus1, float *weights, float *workspace, float *dA,
const LstmParameter *lstm_param) {
float *workspace_i = workspace;
void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate,
float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX,
float *weights, float *workspace, const LstmGradParameter *lstm_param) {
float *scratchPad = workspace;

float *dI = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_); // dI = dC_{t+1} * G
float *temp0 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp1 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp2 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp3 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp4 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);

// Accumulate gradients into dH
ElementAdd(dH, dY, dH, lstm_param->output_step_);

ElementMul(dH, output_gate, temp1, lstm_param->output_step_);
Tanh(cell_state, lstm_param->output_step_, temp0);
ElementMul(temp0, temp0, temp2, lstm_param->output_step_);
ElementMul(temp1, temp2, temp4, lstm_param->output_step_);
ElementSub(temp1, temp4, temp1, lstm_param->output_step_);
ElementAdd(dC, temp1, dC, lstm_param->output_step_);

// calculate dI, dO, dF and dG
float *dI = temp1; // dI = dC_{t} * G
ElementMul(dC, cell_gate, dI, lstm_param->output_step_);
float *tanh_c = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);
float *dO = temp2; // dO = dH * Tanh(C_{t})
ElementMul(dH, temp0, dO, lstm_param->output_step_);
float *dF = temp3; // dF = dC_{t} * C_{t-1}
ElementMul(dC, prev_cell_state, dF, lstm_param->output_step_);
float *dG = temp4; // dG = dC_{t} * I
ElementMul(dC, input_gate, dG, lstm_param->output_step_);

// dAi = dI * I * (1 - I)
float *dAi = temp1;
*dA = dAi;
ElementMul(dI, input_gate, dAi, lstm_param->output_step_);
ElementMul(dAi, input_gate, temp0, lstm_param->output_step_);
ElementSub(dAi, temp0, dAi, lstm_param->output_step_);

// dAo = dO * O * (1 - O)
float *dAo = temp2;
ElementMul(dO, output_gate, dAo, lstm_param->output_step_);
ElementMul(dAo, output_gate, temp0, lstm_param->output_step_);
ElementSub(dAo, temp0, dAo, lstm_param->output_step_);

// dAf = dF * F * (1 - F)
float *dAf = temp3;
ElementMul(dF, forget_gate, dAf, lstm_param->output_step_);
ElementMul(dAf, forget_gate, temp0, lstm_param->output_step_);
ElementSub(dAf, temp0, dAf, lstm_param->output_step_);

float *dAg = temp4;
ElementMul(cell_gate, cell_gate, temp0, lstm_param->output_step_);
ElementMul(dG, temp0, temp0, lstm_param->output_step_);
ElementSub(dG, temp0, dAg, lstm_param->output_step_);

// calculate dX
size_t dX_size = lstm_param->batch_ * lstm_param->input_size_ * sizeof(float);
memset(dX, 0, dX_size);
float *mat_workspace = AllocteFromScrachPad(
&scratchPad, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));
float *weights_loop = weights;
float *dA_loop = dAi; // dAi, dAo, dAf, dAg
for (int idx = 0; idx < num_of_gates; idx++) {
GemmMatmul(0, 0, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, 1.0, dA_loop,
lstm_param->hidden_size_, weights_loop, lstm_param->input_size_, 1.0, dX, lstm_param->input_size_,
mat_workspace);
weights_loop += lstm_param->hidden_size_ * lstm_param->input_size_;
dA_loop += lstm_param->output_step_;
}

// calculate dH next
size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float);
memset(dH, 0, dH_size);
dA_loop = dAi;
for (int idx = 0; idx < num_of_gates; idx++) {
GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop,
lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_,
mat_workspace);
weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_;
dA_loop += lstm_param->output_step_;
}
// calculate dC next
ElementMul(dC, forget_gate, dC, lstm_param->output_step_);
}

// Accumulates one time step's weight gradients into the packed buffer dW:
//   dW += dA^T x input_t            (input weights, one block per gate)
//   dV += dA^T x prev_hidden_state  (state weights, one block per gate)
//   dB += column sums of dA         (biases, only when has_bias_)
// dA holds the four gate pre-activation gradients (dAi, dAo, dAf, dAg),
// each output_step_ floats apart; dW is laid out as [4 x dW | 4 x dV | 4 x dB].
// workspace supplies the GEMM scratch area (see GetGemmMatMullWorkspace).
void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *workspace,
const LstmGradParameter *lstm_param) {
// Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg
float *mat_workspace = AllocteFromScrachPad(
&workspace, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));
float *dA_loop = dA; // dAi, dAo, dAf, dAg
int dW_size = lstm_param->input_size_ * lstm_param->hidden_size_;
int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_;
int dB_size = 0;
float *dW_loop = dW;
float *dV_loop = dW + (num_of_gates * dW_size);
float *dB_loop = 0;
if (lstm_param->has_bias_) {
dB_loop = dW + (num_of_gates * (dW_size + dV_size));
dB_size = lstm_param->hidden_size_;
}

for (int idx = 0; idx < num_of_gates; idx++) {
// Calc dW — transA=1 multiplies dA^T by input_t; beta=1.0 accumulates across time steps
GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->input_size_, lstm_param->batch_, 1.0, dA_loop,
lstm_param->hidden_size_, input_t, lstm_param->input_size_, 1.0, dW_loop, lstm_param->input_size_,
mat_workspace);
// Calc dV — same accumulation against the previous hidden state
GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->batch_, 1.0, dA_loop,
lstm_param->hidden_size_, prev_hidden_state, lstm_param->hidden_size_, 1.0, dV_loop,
lstm_param->hidden_size_, mat_workspace);
// Calc dB
if (dB_loop != 0) {
sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true);
}
dA_loop += lstm_param->output_step_;
dW_loop += dW_size;
dV_loop += dV_size;
dB_loop += dB_size;  // NOTE(review): when has_bias_ is false this adds 0 to a NULL pointer — technically UB; consider guarding
}
}

void LstmGradDoStep(const float *output_gate, float *cell_state, float *cell_state_minus1, float *cell_gate,
float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float *dX, float *weights,
float *dW, float *hidden_state, float *input_t, float *workspace,
const LstmGradParameter *lstm_param) {
float *workspace_i = workspace;

float buffer[1024];
float *scratchPad = buffer;

float *tanh_c = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *temp2 = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
float *tanh_c_sqr = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);

// Accumulate gradients into dH
ElementAdd(dH, dY, dH, lstm_param->output_step_);

ElementMul(dH, output_gate, temp2, lstm_param->output_step_);
Tanh(cell_state, lstm_param->output_step_, tanh_c);
float *dO = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_); // dO = dH * Tanh(C_{t+1})
ElementMul(tanh_c, tanh_c, tanh_c_sqr, lstm_param->output_step_);
ElementMul(temp2, tanh_c_sqr, temp, lstm_param->output_step_);
ElementSub(temp2, temp, temp2, lstm_param->output_step_);
ElementAdd(dC, temp2, dC, lstm_param->output_step_);

// calculate dI, dO, dF and dG
float *dI = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dI = dC_{t} * G
ElementMul(dC, cell_gate, dI, lstm_param->output_step_);
float *dO = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dO = dH * Tanh(C_{t})
ElementMul(dH, tanh_c, dO, lstm_param->output_step_);
float *dF = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_); // dF = dC_{t+1} * C_t
float *dF = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dF = dC_{t} * C_{t-1}
ElementMul(dC, cell_state_minus1, dF, lstm_param->output_step_);
float *dG = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_); // dG = dC_{t+1} * I
float *dG = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_); // dG = dC_{t} * I
ElementMul(dC, input_gate, dG, lstm_param->output_step_);

float *temp = AllocteFromScrachPad(&workspace_i, lstm_param->output_step_);
float *dAi = AllocteFromScrachPad(&dA, lstm_param->output_step_); // dAi = dI * I * (1 - I)
// dAi = dI * I * (1 - I)
float *dAi = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
ElementMul(dI, input_gate, dAi, lstm_param->output_step_);
ElementMul(dAi, input_gate, temp, lstm_param->output_step_);
ElementSub(dAi, temp, dAi, lstm_param->output_step_);
float *dAo = AllocteFromScrachPad(&dA, lstm_param->output_step_); // dAo = dO * O * (1 - O)

// dAo = dO * O * (1 - O)
float *dAo = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
ElementMul(dO, output_gate, dAo, lstm_param->output_step_);
ElementMul(dAo, output_gate, temp, lstm_param->output_step_);
ElementSub(dAo, temp, dAo, lstm_param->output_step_);
float *dAf = AllocteFromScrachPad(&dA, lstm_param->output_step_); // dAf = dF * F * (1 - F)

// dAf = dF * F * (1 - F)
float *dAf = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
ElementMul(dF, forget_gate, dAf, lstm_param->output_step_);
ElementMul(dAf, forget_gate, temp, lstm_param->output_step_);
ElementSub(dAf, temp, dAf, lstm_param->output_step_);
float *dAg = AllocteFromScrachPad(&dA, lstm_param->output_step_); // dAg = dG * (1 - G^2)

// dAg = dG * (1 - G^2)
float *dAg = AllocteFromScrachPad(&scratchPad, lstm_param->output_step_);
ElementMul(cell_gate, cell_gate, dAg, lstm_param->output_step_);
ElementMul(dG, dAg, dAg, lstm_param->output_step_);
ElementSub(dG, dAg, dAg, lstm_param->output_step_);

// calculate dX
float *mat_workspace = AllocteFromScrachPad(
&workspace_i, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));

size_t dX_size = lstm_param->batch_ * lstm_param->input_size_ * sizeof(float);
memset(dX, 0, dX_size);

float *weights_loop = weights;
float *dA_loop = dAi; // dAi, dAo, dAf, dAg
for (int idx = 0; idx < num_of_gates; idx++) {
@@ -157,39 +296,22 @@ void LstmGradStepUnit(float *output, float *input_gate, float *forget_gate, floa
dA_loop += lstm_param->output_step_;
}

// calculate dH next
size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float);
if (dY != NULL) {
memcpy(dH, dY, dH_size);
output_gate -= lstm_param->output_step_;
} else {
memset(dH, 0, dH_size);
}
memset(dH, 0, dH_size);
dA_loop = dAi;
for (int idx = 0; idx < num_of_gates; idx++) {
GemmMatmul(0, 1, lstm_param->hidden_size_, lstm_param->batch_, lstm_param->hidden_size_, 1.0, weights_loop,
lstm_param->hidden_size_, dA_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->batch_, mat_workspace);
GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop,
lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_,
mat_workspace);
weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_;
dA_loop += lstm_param->output_step_;
}

NNACL_ASSERT(workspace_i <= workspace + GetRunWorkspaceSize(lstm_param));

// calculate dC next
ElementMul(dC, forget_gate, dC, lstm_param->output_step_);
ElementMul(dH, output_gate, temp, lstm_param->output_step_);

Tanh(cell_state_minus1, lstm_param->output_step_, tanh_c);
ElementMul(tanh_c, tanh_c, tanh_c, lstm_param->output_step_);
ElementMul(temp, tanh_c, tanh_c, lstm_param->output_step_);
ElementSub(temp, tanh_c, temp, lstm_param->output_step_);
ElementAdd(dC, temp, dC, lstm_param->output_step_);
}

void LstmGradWeightStepUnit(float *input_t, float *hidden_state, float *dA, float *dW, float *workspace,
const LstmParameter *lstm_param) {
// Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg
float *dA_loop = dA;
float *mat_workspace = AllocteFromScrachPad(
&workspace, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));
// Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg
dA_loop = dAi;
int dW_size = lstm_param->input_size_ * lstm_param->hidden_size_;
int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_;
int dB_size = lstm_param->hidden_size_;
@@ -197,13 +319,18 @@ void LstmGradWeightStepUnit(float *input_t, float *hidden_state, float *dA, floa
float *dV_loop = dW + (num_of_gates * dW_size);
float *dB_loop = dW + (num_of_gates * (dW_size + dV_size));
for (int idx = 0; idx < num_of_gates; idx++) {
// Calc dW
GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->input_size_, lstm_param->batch_, 1.0, dA_loop,
lstm_param->hidden_size_, input_t, lstm_param->input_size_, 1.0, dW_loop, lstm_param->input_size_,
mat_workspace); // Calc dW
GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->batch_, 1.0, dA_loop,
lstm_param->hidden_size_, hidden_state, lstm_param->hidden_size_, 1.0, dV_loop, lstm_param->hidden_size_,
mat_workspace); // Calc dV
sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true); // Calc dB
mat_workspace);
// Calc dV
if (hidden_state != 0) {
GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->batch_, 1.0, dA_loop,
lstm_param->hidden_size_, hidden_state, lstm_param->hidden_size_, 1.0, dV_loop,
lstm_param->hidden_size_, mat_workspace);
}
// Calc dB
sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true);
dA_loop += lstm_param->output_step_;
dW_loop += dW_size;
dV_loop += dV_size;


+ 34
- 12
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/lstm_grad_fp32.h View File

@@ -17,29 +17,51 @@
#ifndef MINDSPORE_NNACL_FP32_GRAD_LSTM_GRAD_H_
#define MINDSPORE_NNACL_FP32_GRAD_LSTM_GRAD_H_

#include "nnacl/lstm_parameter.h"
#include "nnacl/op_base.h"

// Parameters shared by the LSTM grad kernels (grad, grad-data, grad-weight).
typedef struct LstmGradParameter {
// Primitive parameter
OpParameter op_parameter_;
// shape correlative
int input_size_;   // width of each input vector x_t (dX is batch_ * input_size_ floats per step)
int hidden_size_;  // output_size
int seq_len_;      // number of time steps
int batch_;
// other parameter
int output_step_;  // floats per per-step gate/state matrix; presumably batch_ * hidden_size_ — confirm against populate code
bool bidirectional_;
float zoneout_cell_;    // NOTE(review): zoneout rates appear unused by the visible grad code — likely carried from forward LSTM params
float zoneout_hidden_;
int input_row_align_;   // packing alignments for the GEMM kernels
int input_col_align_;
int state_row_align_;
int state_col_align_;
int has_bias_;  // nonzero when bias gradients (dB) should be produced
} LstmGradParameter;

#ifdef __cplusplus
extern "C" {
#endif

int GetRunWorkspaceSize(const LstmParameter *lstm_param);
const int *getLstmOrderIOFG(void);

const int *getLstmOrderIFGO(void);

int GetRunWorkspaceSize(const LstmGradParameter *lstm_param);

size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param);

void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align,
const int *order);

void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col, int row, const int *order);

void LstmGradStepUnitInit(const float *output_gate, float *cell_state, float *dY, float *dC, float *dH,
float *workspace, const LstmParameter *lstm_param);

void LstmGradStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
float *hidden_state, float *cell_state, float *dC, float *dH, float *dY, float *dX,
float *cell_state_minus1, float *weights, float *workspace, float *dA,
const LstmParameter *lstm_param);

void LstmGradWeightStepUnit(float *input_t, float *hidden_state, float *dA, float *dW, float *workspace,
const LstmParameter *lstm_param);
void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate,
float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX,
float *weights, float *workspace, const LstmGradParameter *lstm_param);

void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *workspace,
const LstmGradParameter *lstm_param);
#ifdef __cplusplus
}
#endif


+ 13
- 6
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_grad_data_infer.c View File

@@ -18,6 +18,7 @@
#include "nnacl/infer/infer_register.h"
#include "nnacl/infer/common_infer.h"
#include "nnacl/fp32/lstm_fp32.h"
#include "nnacl/fp32_grad/lstm_grad_fp32.h"

int LstmGradDataInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
@@ -25,25 +26,31 @@ int LstmGradDataInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
if (check_ret != NNACL_OK) {
return check_ret;
}
const TensorC *dY = inputs[SECOND_INPUT];
LstmGradParameter *p = (LstmGradParameter *)parameter;
const TensorC *Y = inputs[SECOND_INPUT];
const TensorC *H = inputs[THIRD_INPUT];
const TensorC *C = inputs[FOURTH_INPUT];
const TensorC *weight = inputs[FIFTH_INPUT];
TensorC *dX = outputs[FIRST_INPUT];

int out_shape[MAX_SHAPE_SIZE];
size_t out_shape_size = 0;

for (int i = 0; i < outputs_size; i++) {
SetDataTypeFormat(outputs[i], dY);
SetDataTypeFormat(outputs[i], Y);
}

if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}

if (dY->shape_size_ != C3NUM || weight->shape_size_ != C3NUM) {
if (Y->shape_size_ != C3NUM || weight->shape_size_ != C3NUM) {
return NNACL_ERR;
}
ShapePush(out_shape, &out_shape_size, Y->shape_[out_shape_size]);
ShapePush(out_shape, &out_shape_size, Y->shape_[out_shape_size]);
ShapePush(out_shape, &out_shape_size, p->input_size_);

SetShapeArray(dX, dY->shape_, dY->shape_size_);
SetShapeArray(outputs[FIRST_INPUT], out_shape, C3NUM);
SetShapeArray(outputs[SECOND_INPUT], H->shape_, H->shape_size_);
SetShapeArray(outputs[THIRD_INPUT], C->shape_, C->shape_size_);



+ 1
- 1
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_grad_infer.c View File

@@ -17,7 +17,7 @@
#include "nnacl/infer/lstm_grad_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/infer/common_infer.h"
#include "nnacl/fp32/lstm_fp32.h"
#include "nnacl/fp32_grad/lstm_grad_fp32.h"

int LstmGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {


+ 3
- 3
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_grad_weight_infer.c View File

@@ -17,7 +17,7 @@
#include "nnacl/infer/lstm_grad_weight_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/infer/common_infer.h"
#include "nnacl/fp32/lstm_fp32.h"
#include "nnacl/fp32_grad/lstm_grad_fp32.h"

int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
@@ -42,8 +42,8 @@ int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, T
if (input->shape_size_ != C3NUM || H->shape_size_ != C3NUM || Y->shape_size_ != C3NUM) {
return NNACL_ERR;
}
LstmParameter *param = (LstmParameter *)parameter;
bool has_bias = true;
LstmGradParameter *param = (LstmGradParameter *)parameter;
int has_bias = param->has_bias_;
int output_shape[3] = {0, 1, 1};
int gate_size = 4 * param->hidden_size_;
output_shape[0] += gate_size * param->input_size_;


+ 0
- 1
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/strided_slice_grad_infer.c View File

@@ -138,7 +138,6 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size,
if (!inferflag) {
return NNACL_OK;
}

int output_size = inputs[1]->shape_[0];
int output_shape[MAX_SHAPE_SIZE] = {0};
size_t output_shape_size = 0;


+ 10
- 17
mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc View File

@@ -102,15 +102,20 @@ int LstmCPUKernel::InitInputWeightBias() {
return RET_ERROR;
}
memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float));
float *bias_data =
weight_i_data + gate_num * lstm_param_->hidden_size_ * (lstm_param_->input_size_ + lstm_param_->hidden_size_);

float *bias_data = nullptr;
int offset = gate_num * lstm_param_->hidden_size_ * (lstm_param_->input_size_ + lstm_param_->hidden_size_);
if (weight_i->ElementsNum() > offset) {
bias_data = weight_i_data + offset;
}
if (in_tensors_.size() > mindir_input_tensors) {
bias_data = reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data());
}

CHECK_NULL_RETURN(bias_data);
PackLstmBias(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_col_align_,
lstm_param_->bidirectional_, weights_order);
if (bias_data != nullptr) {
PackLstmBias(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_col_align_,
lstm_param_->bidirectional_, weights_order);
}
return RET_OK;
}

@@ -386,9 +391,6 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons
float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
float *output_ptr = output + real_t * lstm_param_->output_step_;

if (IsTrain() && IsTrainable()) {
RecordPreState(cell_state, is_backward ? real_t : t);
}
LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
hidden_state, cell_state, buffer_, lstm_param_);
if (IsTrain() && IsTrainable()) {
@@ -399,15 +401,6 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons
return RET_OK;
}

void LstmCPUKernel::RecordPreState(float *cell_state_minus1, int step) {
float *states = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data());
auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
auto stride = step * state_size;
auto seq_stride = lstm_param_->seq_len_ * state_size;
stride += (no_of_recorde_values - 1) * seq_stride;
memcpy(states + stride, cell_state_minus1, state_size * sizeof(float));
}

void LstmCPUKernel::RecordStates(float *hidden_state, float *cell_state, float *input_gate, float *output_gate,
float *forget_gate, float *cell_gate, int step) {
float *states = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data());


+ 1
- 3
mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h View File

@@ -51,7 +51,6 @@ class LstmCPUKernel : public InnerKernel {
int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state);
void RecordStates(float *hidden_state, float *cell_state, float *input_gate, float *output_gate, float *forget_gate,
float *cell_gate, int step);
void RecordPreState(float *cell_state_minus1, int step);
const float *weight_loop_;
const float *bias_loop_;
float *gate_loop_ = nullptr;
@@ -73,7 +72,6 @@ class LstmCPUKernel : public InnerKernel {
const int combined_weights_index = 3;
const int mindir_hidden_state_input_index = 1;
const int mindir_cell_state_input_index = 2;
const int no_of_recorde_values = 7;
int hidden_state_input_index_ = onnx_hidden_state_index;
int cell_state_input_index_ = onnx_cell_state_index;

@@ -87,7 +85,7 @@ class LstmCPUKernel : public InnerKernel {
const int hidden_state_index = 5;
const int avx_state_output_index = 6;
static const int out_intermediate_states_index = 3;
const int weights_order_IFOG[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 4}; // IFGO order to IOFG order
const int weights_order_IFOG[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 5}; // IFGO order to IOFG order

int row_tile_ = 0;
int col_tile_ = 0;


+ 62
- 34
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.cc View File

@@ -16,6 +16,7 @@
#include "src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.h"
#include <string>
#include <memory>
#include <algorithm>
#include "utils/ms_utils.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@@ -57,18 +58,31 @@ int LSTMGradDataCPUKernel::Run() {
}

int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backward) {
auto dC_tensor = in_tensors_.at(dC_index); /* [1, Batch, hidden_size ] */
auto dH_tensor = in_tensors_.at(dH_index); /* [1, Batch, hidden_size ] */
auto dy_tensor = in_tensors_.at(dy_index); /* [SeqLen, Batch, insize ] */
auto dX_tensor = out_tensors_.at(dX_out_index);
auto weights_tensor = in_tensors_.at(weights_index); /* [all weights + biases, 1, 1] */
auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
MS_ASSERT(dy_tensor != nullptr);
// get input tensors
auto dC_tensor = in_tensors_.at(dC_index);
MS_ASSERT(dC_tensor != nullptr);
auto dH_tensor = in_tensors_.at(dH_index);
MS_ASSERT(dH_tensor != nullptr);
MS_ASSERT(dX_tensor != nullptr);
auto dy_tensor = in_tensors_.at(dy_index);
MS_ASSERT(dy_tensor != nullptr);
auto weights_tensor = in_tensors_.at(weights_index);
MS_ASSERT(weights_tensor != nullptr);
auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
MS_ASSERT(intermediate_tensor != nullptr);
auto cell_input_tensor = in_tensors_.at(cell_input_index);
MS_ASSERT(cell_input_tensor != nullptr);

// Get output tensors
auto dX_tensor = out_tensors_.at(dX_out_index);
MS_ASSERT(dX_tensor != nullptr);
auto dH_out_tensor = out_tensors_.at(dH_out_index);
MS_ASSERT(dH_out_tensor != nullptr);
auto dC_out_tensor = out_tensors_.at(dC_out_index);
MS_ASSERT(dC_out_tensor != nullptr);

auto cell_input_data = reinterpret_cast<float *>(cell_input_tensor->data());
auto dh_out = reinterpret_cast<float *>(dH_out_tensor->data());
auto dc_out = reinterpret_cast<float *>(dC_out_tensor->data());
auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data());
auto dC = reinterpret_cast<float *>(dC_tensor->data());
auto dH = reinterpret_cast<float *>(dH_tensor->data());
@@ -78,43 +92,48 @@ int LSTMGradDataCPUKernel::LstmBackpropUnidirectional(float *output, bool is_bac

auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
auto seq_stride = lstm_param_->seq_len_ * state_size;
float *hidden_state = intermediate_data;
float *cell_state = intermediate_data + seq_stride * 1;
float *input_gate = intermediate_data + seq_stride * 2;
float *output_gate = intermediate_data + seq_stride * 3;
float *forget_gate = intermediate_data + seq_stride * 4;
float *cell_gate = intermediate_data + seq_stride * 5;
float *cell_state_minus1 = intermediate_data + seq_stride * 6;

bool first_time = true;
// reorder weights only from IFGO to IOFG
ReorderLstmWeightGrad(weights_tmp_, weights);
memset(dH, 0, dH_tensor->Size());
memset(dC, 0, dC_tensor->Size());
for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) {
int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
auto stride = real_t * state_size;
float *hidden_state_t = hidden_state + stride;
float *cell_state_t = cell_state + stride;
float *input_gate_t = input_gate + stride;
float *forget_gate_t = forget_gate + stride;
float *cell_gate_t = cell_gate + stride;
float *output_gate_t = output_gate + stride;
float *cell_state_minus1_t = cell_state_minus1 + stride;
float *output_ptr = output + real_t * lstm_param_->output_step_;
float *dX_t = dX + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *dY_t = nullptr;

float *curr_cell_state = cell_state + stride;
float *prev_cell_state = (real_t > 0) ? cell_state + (real_t - 1) * state_size : cell_input_data;
float *curr_input_gate = input_gate + stride;
float *curr_forget_gate = forget_gate + stride;
float *curr_cell_gate = cell_gate + stride;
float *curr_output_gate = output_gate + stride;
float *curr_dx = dX + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *curr_dy = dY + real_t * state_size;

float *dA = nullptr;
LstmGradDoInputStep(curr_output_gate, curr_cell_state, prev_cell_state, curr_cell_gate, curr_input_gate,
curr_forget_gate, curr_dy, dC, dH, &dA, curr_dx, weights_tmp_, workspace_, lstm_param_);
float *dA_t = dA_tmp_ + t * num_of_gates * lstm_param_->output_step_;
if (first_time) {
dY_t = dY + real_t * lstm_param_->batch_ * lstm_param_->hidden_size_;
LstmGradStepUnitInit(output_gate_t, cell_state_t, dY_t, dC, dH, workspace_, lstm_param_);
first_time = false;
}
dY_t = (real_t > 0) ? dY + (real_t - 1) * lstm_param_->batch_ * lstm_param_->hidden_size_ : nullptr;
LstmGradStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, hidden_state_t, cell_state_t,
dC, dH, dY_t, dX_t, cell_state_minus1_t, weights, workspace_, dA_t, lstm_param_);
std::copy(&(dA[0]), &(dA[num_of_gates * lstm_param_->output_step_]), &dA_t[0]); // for w grad step
}
// Copy dA matrices into the intermediate data buffer after the hidden state. Both hidden state and dA are needed by LstmGradWeight.
memcpy(cell_state, dA_tmp_, sizeof(float) * num_of_gates * lstm_param_->output_step_ * lstm_param_->seq_len_);
std::copy(&(dH[0]), &(dH[state_size]), &(dh_out[0]));
std::copy(&(dC[0]), &(dC[state_size]), &(dc_out[0]));
std::copy(&(dA_tmp_[0]), &(dA_tmp_[num_of_gates * lstm_param_->output_step_ * lstm_param_->seq_len_]),
&(cell_state[0]));
return RET_OK;
}

void LSTMGradDataCPUKernel::ReorderLstmWeightGrad(float *dst, float *src) {
  // Gate-reorder the packed weight gradients into IFGO order.
  // Layout in memory: first the input-weight block (per-gate hidden_size x input_size
  // matrices), immediately followed by the recurrent-weight block (per-gate
  // hidden_size x hidden_size matrices). Reorder each block in place-parallel
  // fashion by advancing both pointers past the first block.
  const int input_weight_block = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
  ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIFGO());
  ReorderLstmWeights(dst + input_weight_block, src + input_weight_block, weight_batch_, lstm_param_->hidden_size_,
                     lstm_param_->hidden_size_, getLstmOrderIFGO());
}

int LSTMGradDataCPUKernel::DoGrad(int thread_id) { return RET_OK; }

int LSTMGradDataCPUKernel::InitParam() {
@@ -123,7 +142,6 @@ int LSTMGradDataCPUKernel::InitParam() {
std::vector<int> in_shape = input->shape();
lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT);
lstm_param_->batch_ = in_shape.at(SECOND_INPUT);
lstm_param_->input_size_ = in_shape.at(THIRD_INPUT);

auto dy = in_tensors_.at(dy_index);
MS_ASSERT(dy != nullptr);
@@ -185,7 +203,13 @@ int LSTMGradDataCPUKernel::MallocRunBuffer() {
MS_LOG(ERROR) << "LstmGradDataCPUKernel malloc run dA_tmp alloc error.";
return RET_ERROR;
}

int weights_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_ + // IW matrices
weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; // V matrices
weights_tmp_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(weights_size * sizeof(float)));
if (weights_tmp_ == nullptr) {
MS_LOG(ERROR) << "LstmGradWeightCPUKernel malloc run weights_tmp_ alloc error.";
return RET_ERROR;
}
return RET_OK;
}

@@ -198,6 +222,10 @@ void LSTMGradDataCPUKernel::FreeRunBuffer() {
ms_context_->allocator->Free(dA_tmp_);
dA_tmp_ = nullptr;
}
if (weights_tmp_ != nullptr) {
ms_context_->allocator->Free(weights_tmp_);
weights_tmp_ = nullptr;
}
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTMGradData, LiteKernelCreator<LSTMGradDataCPUKernel>)


+ 6
- 18
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_data_fp32.h View File

@@ -29,8 +29,8 @@ class LSTMGradDataCPUKernel : public InnerKernel {
public:
explicit LSTMGradDataCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx), thread_count_(ctx->thread_num_) {
lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
: InnerKernel(parameter, inputs, outputs, ctx) {
lstm_param_ = reinterpret_cast<LstmGradParameter *>(op_parameter_);
}
~LSTMGradDataCPUKernel() {}
int Prepare() override;
@@ -41,18 +41,16 @@ class LSTMGradDataCPUKernel : public InnerKernel {
private:
int LstmBackpropUnidirectional(float *output, bool is_backward);

void ReorderLstmWeightGrad(float *dst, float *src);
int InitParam();
int MallocRunBuffer();
void FreeRunBuffer();
int InitInputWeightBias();
int InitStateWeightBias();

int thread_count_;
static const int dy_index = 1;
static const int dH_index = 2;
static const int dC_index = 3;
static const int weights_index = 4;
static const int cell_state_index = 6;
static const int cell_input_index = 6;
static const int intermediate_data_index = 7;
static const int dX_out_index = 0;
static const int dH_out_index = 1;
@@ -61,19 +59,9 @@ class LSTMGradDataCPUKernel : public InnerKernel {

int input_size_align_ = 1;
float *dA_tmp_ = nullptr;
float *weights_tmp_ = nullptr;
float *workspace_ = nullptr;

int64_t weight_size_ = 0;
int64_t weight_h_size_ = 0;
int64_t input_size_;
int64_t hidden_size_;
int64_t num_layers_;
int64_t batch_size_;
int64_t seq_len_;
int num_directions_;
bool bidirectional_;
bool has_bias_;
size_t reserve_size_;
int row_tile_ = 0;
int col_tile_ = 0;
int state_row_tile_ = 0;
@@ -83,7 +71,7 @@ class LSTMGradDataCPUKernel : public InnerKernel {
int input_thread_count_ = 0;
int input_thread_stride_ = 0;

LstmParameter *lstm_param_ = nullptr;
LstmGradParameter *lstm_param_ = nullptr;
};
} // namespace kernel
} // namespace mindspore


+ 75
- 50
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.cc View File

@@ -16,6 +16,7 @@
#include "src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.h"
#include <string>
#include <memory>
#include <algorithm>
#include "utils/ms_utils.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@@ -46,33 +47,44 @@ int LSTMGradCPUKernel::Run() {
FreeRunBuffer();
return RET_ERROR;
}

auto output = out_tensors_.at(0);
auto output_ptr = reinterpret_cast<float *>(output->data());
CHECK_NULL_RETURN(output_ptr);

LstmBackpropUnidirectional(output_ptr, false);
LstmBackpropUnidirectional(false);
FreeRunBuffer();
return RET_OK;
}

int LSTMGradCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backward) {
auto dC_tensor = in_tensors_.at(dC_index); /* [1, Batch, hidden_size ] */
auto dH_tensor = in_tensors_.at(dH_index); /* [1, Batch, hidden_size ] */
auto dy_tensor = in_tensors_.at(dy_index); /* [SeqLen, Batch, insize ] */
auto dW_tensor = out_tensors_.at(dW_out_index);
auto dX_tensor = out_tensors_.at(dX_out_index);
auto weights_tensor = in_tensors_.at(weights_index); /* [all weights + biases, 1, 1] */
auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
auto input_tensor = in_tensors_.at(input_index);
MS_ASSERT(dy_tensor != nullptr);
int LSTMGradCPUKernel::LstmBackpropUnidirectional(bool is_backward) {
// get input tensors
auto dC_tensor = in_tensors_.at(dC_index);
MS_ASSERT(dC_tensor != nullptr);
auto dH_tensor = in_tensors_.at(dH_index);
MS_ASSERT(dH_tensor != nullptr);
MS_ASSERT(dW_tensor != nullptr);
MS_ASSERT(dX_tensor != nullptr);
auto dy_tensor = in_tensors_.at(dy_index);
MS_ASSERT(dy_tensor != nullptr);
auto input_tensor = in_tensors_.at(input_index);
MS_ASSERT(input_tensor != nullptr);
auto weights_tensor = in_tensors_.at(weights_index);
MS_ASSERT(weights_tensor != nullptr);
auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
MS_ASSERT(intermediate_tensor != nullptr);
MS_ASSERT(input_tensor != nullptr);
auto cell_input_tensor = in_tensors_.at(cell_input_index);
MS_ASSERT(cell_input_tensor != nullptr);
auto hidden_input_tensor = in_tensors_.at(hidden_input_index);
MS_ASSERT(hidden_input_tensor != nullptr);

// Get output tensors
auto dW_tensor = out_tensors_.at(dW_out_index);
MS_ASSERT(dW_tensor != nullptr);
auto dX_tensor = out_tensors_.at(dX_out_index);
MS_ASSERT(dX_tensor != nullptr);
auto dH_out_tensor = out_tensors_.at(dH_out_index);
MS_ASSERT(dH_out_tensor != nullptr);
auto dC_out_tensor = out_tensors_.at(dC_out_index);
MS_ASSERT(dC_out_tensor != nullptr);

auto cell_input_data = reinterpret_cast<float *>(cell_input_tensor->data());
auto hidden_input_data = reinterpret_cast<float *>(hidden_input_tensor->data());
auto dh_out = reinterpret_cast<float *>(dH_out_tensor->data());
auto dc_out = reinterpret_cast<float *>(dC_out_tensor->data());
auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data());
auto dC = reinterpret_cast<float *>(dC_tensor->data());
auto dH = reinterpret_cast<float *>(dH_tensor->data());
@@ -80,7 +92,6 @@ int LSTMGradCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backwar
auto dW = reinterpret_cast<float *>(dW_tensor->data());
auto dX = reinterpret_cast<float *>(dX_tensor->data());
auto weights = reinterpret_cast<float *>(weights_tensor->data());

auto input = reinterpret_cast<float *>(input_tensor->data());

auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
@@ -91,48 +102,50 @@ int LSTMGradCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backwar
float *output_gate = intermediate_data + seq_stride * 3;
float *forget_gate = intermediate_data + seq_stride * 4;
float *cell_gate = intermediate_data + seq_stride * 5;
float *cell_state_minus1 = intermediate_data + seq_stride * 6;
float *workspace_i = workspace_;
float *dA = workspace_i;
workspace_i += num_of_gates * lstm_param_->output_step_;
ReorderLstmWeightGrad(weights_tmp_, weights, getLstmOrderIFGO(), false);

memset(dH, 0, dH_tensor->Size());
memset(dC, 0, dC_tensor->Size());

memset(dW_tmp_, 0, dW_tensor->Size()); // dW_tmp is summed in the loop
bool first_time = true;
float *workspace_gemm = workspace_ + GetRunWorkspaceGemmOffset(lstm_param_);

for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) {
int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
auto stride = real_t * state_size;
float *hidden_state_t = hidden_state + stride;
float *cell_state_t = cell_state + stride;
float *input_gate_t = input_gate + stride;
float *forget_gate_t = forget_gate + stride;
float *cell_gate_t = cell_gate + stride;
float *output_gate_t = output_gate + stride;
float *cell_state_minus1_t = cell_state_minus1 + stride;
float *output_ptr = output + real_t * lstm_param_->output_step_;
float *input_ptr = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *dX_t = dX + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *dY_t = nullptr;
if (first_time) {
dY_t = dY + real_t * lstm_param_->batch_ * lstm_param_->hidden_size_;
LstmGradStepUnitInit(output_gate_t, cell_state_t, dY_t, dC, dH, workspace_i, lstm_param_);
first_time = false;
}
dY_t = (real_t > 0) ? dY + (real_t - 1) * lstm_param_->batch_ * lstm_param_->hidden_size_ : nullptr;
LstmGradStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, hidden_state_t, cell_state_t,
dC, dH, dY_t, dX_t, cell_state_minus1_t, weights, workspace_i, dA, lstm_param_);
LstmGradWeightStepUnit(input_ptr, hidden_state_t, dA, dW_tmp_, workspace_i, lstm_param_);

float *prev_hidden_state = (real_t > 0) ? hidden_state + (real_t - 1) * state_size : hidden_input_data;
float *curr_cell_state = cell_state + stride;
float *prev_cell_state = (real_t > 0) ? cell_state + (real_t - 1) * state_size : cell_input_data;
float *curr_input_gate = input_gate + stride;
float *curr_forget_gate = forget_gate + stride;
float *curr_cell_gate = cell_gate + stride;
float *curr_output_gate = output_gate + stride;
float *curr_input = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *curr_dx = dX + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *curr_dy = dY + real_t * state_size;

float *dA = nullptr;
LstmGradDoInputStep(curr_output_gate, curr_cell_state, prev_cell_state, curr_cell_gate, curr_input_gate,
curr_forget_gate, curr_dy, dC, dH, &dA, curr_dx, weights_tmp_, workspace_, lstm_param_);
LstmGradDoWeightStep(curr_input, prev_hidden_state, dA, dW_tmp_, workspace_gemm, lstm_param_);
}
ReorderLstmWeightGrad(dW, dW_tmp_);
std::copy(&(dH[0]), &(dH[state_size]), &(dh_out[0]));
std::copy(&(dC[0]), &(dC[state_size]), &(dc_out[0]));
ReorderLstmWeightGrad(dW, dW_tmp_, getLstmOrderIOFG(), lstm_param_->has_bias_);
return RET_OK;
}

void LSTMGradCPUKernel::ReorderLstmWeightGrad(float *dst, float *src) {
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, weights_order_IOFG);
void LSTMGradCPUKernel::ReorderLstmWeightGrad(float *dst, float *src, const int *order, bool include_bias) {
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, order);
src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, weights_order_IOFG);
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, order);
src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, weights_order_IOFG);
if (include_bias) {
ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, order);
}
}

int LSTMGradCPUKernel::DoGrad(int thread_id) { return RET_OK; }
@@ -206,6 +219,14 @@ int LSTMGradCPUKernel::MallocRunBuffer() {
MS_LOG(ERROR) << "LstmCPUKernel malloc run dW_tmp alloc error.";
return RET_ERROR;
}
int weights_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_ + // IW matrices
weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; // V matrices
weights_tmp_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(weights_size * sizeof(float)));
if (weights_tmp_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc run weights_tmp_ alloc error.";
return RET_ERROR;
}

return RET_OK;
}

@@ -218,6 +239,10 @@ void LSTMGradCPUKernel::FreeRunBuffer() {
ms_context_->allocator->Free(dW_tmp_);
dW_tmp_ = nullptr;
}
if (weights_tmp_ != nullptr) {
ms_context_->allocator->Free(weights_tmp_);
weights_tmp_ = nullptr;
}
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTMGrad, LiteKernelCreator<LSTMGradCPUKernel>)


+ 7
- 7
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_fp32.h View File

@@ -30,7 +30,7 @@ class LSTMGradCPUKernel : public InnerKernel {
explicit LSTMGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) { // }, thread_count_(ctx->thread_num_) {
lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
lstm_param_ = reinterpret_cast<LstmGradParameter *>(op_parameter_);
}
~LSTMGradCPUKernel() {}
int Prepare() override;
@@ -39,15 +39,16 @@ class LSTMGradCPUKernel : public InnerKernel {
int DoGrad(int thread_id);

private:
int LstmBackpropUnidirectional(float *output, bool is_backward);
int LstmBackpropUnidirectional(bool is_backward);

int InitParam();
int MallocRunBuffer();
void FreeRunBuffer();
void ReorderLstmWeightGrad(float *dst, float *src);
void ReorderLstmWeightGrad(float *dst, float *src, const int *order, bool include_bias);

static const int input_index = 0;
static const int cell_state_index = 2;
static const int hidden_input_index = 1;
static const int cell_input_index = 2;
static const int weights_index = 3;
static const int dy_index = 7;
static const int dH_index = 8;
@@ -58,12 +59,11 @@ class LSTMGradCPUKernel : public InnerKernel {
static const int dC_out_index = 2;
static const int dW_out_index = 3;
static const int num_of_gates = 4;
const int weights_order_IOFG[2 * 4] = {0, 3, 1, 2, 4, 7, 5, 6}; // IOFG order to IFGO order

int input_size_align_ = 1;
float *dW_tmp_ = nullptr;
float *weights_tmp_ = nullptr;
float *workspace_ = nullptr;

int row_tile_ = 0;
int col_tile_ = 0;
int state_row_tile_ = 0;
@@ -73,7 +73,7 @@ class LSTMGradCPUKernel : public InnerKernel {
int input_thread_count_ = 0;
int input_thread_stride_ = 0;

LstmParameter *lstm_param_ = nullptr;
LstmGradParameter *lstm_param_ = nullptr;
};
} // namespace kernel
} // namespace mindspore


+ 21
- 17
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_weight_fp32.cc View File

@@ -58,15 +58,19 @@ int LSTMGradWeightCPUKernel::Run() {

int LSTMGradWeightCPUKernel::LstmBackpropUnidirectional(float *output, bool is_backward) {
auto dW_tensor = out_tensors_.at(dW_out_index);
auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
auto input_tensor = in_tensors_.at(input_index);
MS_ASSERT(dW_tensor != nullptr);
auto intermediate_tensor = in_tensors_.at(intermediate_data_index);
MS_ASSERT(intermediate_tensor != nullptr);
auto input_tensor = in_tensors_.at(input_index);
MS_ASSERT(input_tensor != nullptr);
auto hidden_input_tensor = in_tensors_.at(hidden_input_index);
MS_ASSERT(hidden_input_tensor != nullptr);

auto intermediate_data = reinterpret_cast<float *>(intermediate_tensor->data());
auto input = reinterpret_cast<float *>(input_tensor->data());

auto dW = reinterpret_cast<float *>(dW_tensor->data());
auto hidden_input_data = reinterpret_cast<float *>(hidden_input_tensor->data());

auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_;
auto seq_stride = lstm_param_->seq_len_ * state_size;
float *hidden_state = intermediate_data;
@@ -75,27 +79,28 @@ int LSTMGradWeightCPUKernel::LstmBackpropUnidirectional(float *output, bool is_b
memset(dW_tmp_, 0, dW_tensor->Size()); // dW_tmp is summed in the loop
for (int t = lstm_param_->seq_len_ - 1; t >= 0; t--) {
int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
auto stride = real_t * state_size;
float *input_ptr = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *hidden_state_t = hidden_state + stride;
float *dA_t = dA + t * num_of_gates * lstm_param_->output_step_;
LstmGradWeightStepUnit(input_ptr, hidden_state_t, dA_t, dW_tmp_, workspace_, lstm_param_);
float *curr_input = input + real_t * lstm_param_->batch_ * lstm_param_->input_size_;
float *prev_hidden_state = (real_t > 0) ? hidden_state + (real_t - 1) * state_size : hidden_input_data;
float *curr_da = dA + real_t * num_of_gates * lstm_param_->output_step_;
LstmGradDoWeightStep(curr_input, prev_hidden_state, curr_da, dW_tmp_, workspace_, lstm_param_);
}
ReorderLstmWeightGrad(dW, dW_tmp_);
ReorderLstmWeightGrad(dW, dW_tmp_, lstm_param_->has_bias_);
return RET_OK;
}

void LSTMGradWeightCPUKernel::ReorderLstmWeightGrad(float *dst, float *src) {
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, weights_order_IOFG);
void LSTMGradWeightCPUKernel::ReorderLstmWeightGrad(float *dst, float *src, bool has_bias) {
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_size_, getLstmOrderIOFG());
src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_;
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, weights_order_IOFG);
ReorderLstmWeights(dst, src, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_, getLstmOrderIOFG());
src += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
dst += weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, weights_order_IOFG);
// update second bias term (only if separate GradData and GradWeight)
dst += weight_batch_ * lstm_param_->hidden_size_;
ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, weights_order_IOFG);
if (has_bias) {
ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, getLstmOrderIOFG());
// update second bias term (only if separate GradData and GradWeight)
dst += weight_batch_ * lstm_param_->hidden_size_;
ReorderLstmWeights(dst, src, weight_batch_, 1, lstm_param_->hidden_size_, getLstmOrderIOFG());
}
}

int LSTMGradWeightCPUKernel::DoGrad(int thread_id) { return RET_OK; }
@@ -107,7 +112,6 @@ int LSTMGradWeightCPUKernel::InitParam() {
lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT);
lstm_param_->batch_ = in_shape.at(SECOND_INPUT);
lstm_param_->input_size_ = in_shape.at(THIRD_INPUT);

auto y = in_tensors_.at(y_index);
MS_ASSERT(y != nullptr);
std::vector<int> y_shape = y->shape();


+ 7
- 20
mindspore/lite/src/runtime/kernel/arm/fp32_grad/lstm_grad_weight_fp32.h View File

@@ -29,8 +29,8 @@ class LSTMGradWeightCPUKernel : public InnerKernel {
public:
explicit LSTMGradWeightCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx), thread_count_(ctx->thread_num_) {
lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
: InnerKernel(parameter, inputs, outputs, ctx) {
lstm_param_ = reinterpret_cast<LstmGradParameter *>(op_parameter_);
}
~LSTMGradWeightCPUKernel() {}
int Prepare() override;
@@ -44,34 +44,21 @@ class LSTMGradWeightCPUKernel : public InnerKernel {
int InitParam();
int MallocRunBuffer();
void FreeRunBuffer();
int InitInputWeightBias();
int InitStateWeightBias();
void ReorderLstmWeightGrad(float *dst, float *src);
// TODO: remove these unused declarations:
// int InitInputWeightBias();
// int InitStateWeightBias();
void ReorderLstmWeightGrad(float *dst, float *src, bool has_bias);

int thread_count_;
static const int input_index = 0;
static const int hidden_state_index = 1;
static const int hidden_input_index = 1;
static const int y_index = 2;
static const int intermediate_data_index = 3;
static const int dW_out_index = 0;
static const int num_of_gates = 4;
const int weights_order_IOFG[2 * 4] = {0, 3, 1, 2, 4, 7, 5, 6}; // IOFG order to IFGO order

int input_size_align_ = 1;
float *dW_tmp_ = nullptr;
float *workspace_ = nullptr;

int64_t weight_size_ = 0;
int64_t weight_h_size_ = 0;
int64_t input_size_;
int64_t hidden_size_;
int64_t num_layers_;
int64_t batch_size_;
int64_t seq_len_;
int num_directions_;
bool bidirectional_;
bool has_bias_;
size_t reserve_size_;
int row_tile_ = 0;
int col_tile_ = 0;
int state_row_tile_ = 0;
@@ -81,7 +68,7 @@ class LSTMGradWeightCPUKernel : public InnerKernel {
int input_thread_count_ = 0;
int input_thread_stride_ = 0;

LstmParameter *lstm_param_ = nullptr;
LstmGradParameter *lstm_param_ = nullptr;
};
} // namespace kernel
} // namespace mindspore


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc View File

@@ -64,7 +64,7 @@ int DoSgdInit(float *weight, float *accumulate, float *gradient, float *stat, fl
weight[i] -= accumulate[i] * learning_rate;
}
}
*stat = 1.0f;
*stat = 0.0f;
return RET_OK;
}



+ 15
- 11
mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc View File

@@ -1,3 +1,4 @@

/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
@@ -56,31 +57,34 @@ void StridedSliceGradCPUKernel::FillEmptyDims() {
int32_t strides[DIMENSION_8D];
int32_t input_shape[DIMENSION_8D];
int32_t i;
for (i = 0; i < param_->num_axes_; ++i) {

// invert the order of the dimensions and fill defaults outside the actual range
for (i = 0; i < DIMENSION_8D; ++i) {
begins[i] = param_->begins_[i];
ends[i] = MSMIN(param_->ends_[i], param_->in_shape_[i]);
ends[i] = param_->ends_[i];
strides[i] = param_->strides_[i];
input_shape[i] = param_->in_shape_[i];
}
for (i = param_->num_axes_; i < param_->in_shape_length_; ++i) {
input_shape[i] = param_->in_shape_[i];
begins[i] = 0;
ends[i] = param_->in_shape_[i];
strides[i] = 1;
}

int32_t real_index = param_->in_shape_length_ - 1;
for (i = DIMENSION_8D - 1; i >= 0; --i) {
if (real_index >= 0) {
param_->in_shape_[i] = input_shape[real_index--];
} else {
param_->in_shape_[i] = 1;
}
}
int out_shape_length = in_tensors_.at(1)->shape().at(0);
real_index = out_shape_length - 1;
for (i = DIMENSION_8D - 1; i >= 0; --i) {
if (real_index >= 0) {
param_->begins_[i] = begins[real_index];
param_->ends_[i] = ends[real_index];
param_->strides_[i] = strides[real_index];
param_->in_shape_[i] = input_shape[real_index--];
param_->strides_[i] = strides[real_index--];
} else {
param_->begins_[i] = 0;
param_->ends_[i] = 1;
param_->strides_[i] = 1;
param_->in_shape_[i] = 1;
}
}
param_->num_axes_ = DIMENSION_8D;


+ 16
- 9
mindspore/lite/src/train/train_populate_parameter.cc View File

@@ -30,6 +30,7 @@
#include "nnacl/fp32_grad/dropout_parameter.h"
#include "nnacl/fp32_grad/smooth_l1_loss.h"
#include "nnacl/fp32_grad/resize_grad_parameter.h"
#include "nnacl/fp32_grad/lstm_grad_fp32.h"

using mindspore::lite::Registry;

@@ -511,17 +512,20 @@ OpParameter *PopulateLstmGradParameter(const void *prim) {
return nullptr;
}

auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc LstmParameter failed.";
MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(LstmParameter));
memset(param, 0, sizeof(LstmGradParameter));

param->op_parameter_.type_ = primitive->value_type();
param->bidirectional_ = value->bidirectional();
param->zoneout_cell_ = value->zoneout_cell();
param->zoneout_hidden_ = value->zoneout_hidden();
param->input_size_ = value->input_size();
param->has_bias_ = value->has_bias();

return reinterpret_cast<OpParameter *>(param);
}

@@ -534,17 +538,19 @@ OpParameter *PopulateLstmGradDataParameter(const void *prim) {
return nullptr;
}

auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc LstmParameter failed.";
MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(LstmParameter));
memset(param, 0, sizeof(LstmGradParameter));

param->op_parameter_.type_ = primitive->value_type();
param->bidirectional_ = value->bidirectional();
param->zoneout_cell_ = value->zoneout_cell();
param->zoneout_hidden_ = value->zoneout_hidden();
param->input_size_ = value->input_size();
param->has_bias_ = value->has_bias();
return reinterpret_cast<OpParameter *>(param);
}

@@ -557,12 +563,12 @@ OpParameter *PopulateLstmGradWeightParameter(const void *prim) {
return nullptr;
}

auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc LstmParameter failed.";
MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(LstmParameter));
memset(param, 0, sizeof(LstmGradParameter));

param->op_parameter_.type_ = primitive->value_type();
param->input_size_ = value->input_size();
@@ -570,6 +576,7 @@ OpParameter *PopulateLstmGradWeightParameter(const void *prim) {
param->bidirectional_ = value->bidirectional();
param->zoneout_cell_ = value->zoneout_cell();
param->zoneout_hidden_ = value->zoneout_hidden();
param->has_bias_ = value->has_bias();
return reinterpret_cast<OpParameter *>(param);
}



+ 1
- 0
mindspore/lite/tools/benchmark_train/net_train.h View File

@@ -171,6 +171,7 @@ class MS_API NetTrain {
float CompareData(const float *refOutput, int size, T *msTensorData) {
size_t errorCount = 0;
float meanError = 0;
std::cout << "Out tensor size is: " << size << std::endl;
std::cout << "Data of model output: ";
for (int j = 0; j < std::min(50, size); j++) {
std::cout << static_cast<float>(msTensorData[j]) << " ";


Loading…
Cancel
Save