
!10695 lstm add smooth param & add gru op & add 3 rnn fusion pass

From: @wangzhe128
tags/v1.2.0-rc1
Committed by mindspore-ci-bot
parent commit fdf534d145
33 changed files with 2857 additions and 14 deletions
  1. +134  -0   mindspore/lite/nnacl/fp32/gru_fp32.c
  2. +43   -0   mindspore/lite/nnacl/fp32/gru_fp32.h
  3. +60   -8   mindspore/lite/nnacl/fp32/lstm_fp32.c
  4. +12   -1   mindspore/lite/nnacl/fp32/lstm_fp32.h
  5. +1    -0   mindspore/lite/schema/model.fbs
  6. +5    -0   mindspore/lite/schema/ops.fbs
  7. +121  -0   mindspore/lite/src/ops/gru.cc
  8. +47   -0   mindspore/lite/src/ops/gru.h
  9. +6    -1   mindspore/lite/src/ops/lstm.cc
  10. +2   -0   mindspore/lite/src/ops/lstm.h
  11. +42  -0   mindspore/lite/src/ops/populate/gru_populate.cc
  12. +1   -0   mindspore/lite/src/ops/populate/lstm_populate.cc
  13. +3   -0   mindspore/lite/src/ops/primitive_c.cc
  14. +165 -0   mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc
  15. +52  -0   mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h
  16. +14  -1   mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
  17. +1   -0   mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
  18. +3   -0   mindspore/lite/test/CMakeLists.txt
  19. +1   -0   mindspore/lite/test/models_tf.cfg
  20. +35  -0   mindspore/lite/test/run_benchmark_nets.sh
  21. +3   -0   mindspore/lite/tools/converter/CMakeLists.txt
  22. +6   -0   mindspore/lite/tools/converter/anf_transform.cc
  23. +12  -1   mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  24. +5   -1   mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc
  25. +61  -0   mindspore/lite/tools/converter/parser/tf/tf_select_parser.cc
  26. +37  -0   mindspore/lite/tools/converter/parser/tf/tf_select_parser.h
  27. +1   -1   mindspore/lite/tools/converter/parser/tf/tf_util.cc
  28. +679 -0   mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
  29. +73  -0   mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h
  30. +370 -0   mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
  31. +53  -0   mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h
  32. +727 -0   mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
  33. +82  -0   mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h

+134 -0  mindspore/lite/nnacl/fp32/gru_fp32.c

@@ -0,0 +1,134 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/gru_fp32.h"
#include <string.h>
#include "nnacl/fp32/lstm_fp32.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/arithmetic_fp32.h"

void InitGruGate(float *gate_buffer, const float *bias, const GruParameter *gru_parm) {
  int gate_offest = 0;
  for (int l = 0; l < 3; l++) {
    int batch_offest = gate_offest;
    int bias_offest = l * gru_parm->hidden_size_;
    for (int b = 0; b < gru_parm->batch_; b++) {
      memcpy(gate_buffer + batch_offest, bias + bias_offest, gru_parm->hidden_size_ * sizeof(float));
      batch_offest += gru_parm->hidden_size_;
    }
    gate_offest += gru_parm->batch_ * gru_parm->hidden_size_;
  }
}

void GruStepUnit(float *output, const float *input, const float *input_reset_weight, const float *input_update_weight,
                 const float *input_hidden_weight, const float *state_reset_weight, const float *state_update_weight,
                 const float *state_hidden_weight, const float *bias, float *hidden_state, float *gate_buffer,
                 const GruParameter *gru_parm) {
  InitGruGate(gate_buffer, bias, gru_parm);

  float *update_gate = gate_buffer;
  float *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
  float *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;

  // input * weight
  MatMulAcc(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
  MatMulAcc(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
  MatMulAcc(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);

  // state * weight
  MatMulAcc(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);
  MatMulAcc(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);

  // update reset_gate
  Sigmoid(reset_gate, gru_parm->batch_ * gru_parm->hidden_size_, reset_gate);

  // update update_gate
  Sigmoid(update_gate, gru_parm->batch_ * gru_parm->hidden_size_, update_gate);

  ElementMul(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
  MatMulAcc(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);

  Tanh(hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_, hidden_buffer);

  ElementMul(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);

  ArithmeticParameter parameter;
  parameter.in_elements_num0_ = 1;
  parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
  const float one = 1.0f;
  ElementOptSub(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter);

  ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);

  memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float));
}

void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
         float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm) {
  // forward
  const float *input_update_weight = weight_g;
  const float *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
  const float *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;

  const float *state_update_weight = weight_r;
  const float *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
  const float *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;

  for (int t = 0; t < check_seq_len; t++) {
    const float *input_ptr = input + t * gru_parm->input_step_;
    float *output_ptr = output + t * gru_parm->output_step_;
    GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
                state_reset_weight, state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer,
                gru_parm);
  }
  // zero out extra fw outputs
  for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
    float *output_ptr = output + t * gru_parm->output_step_;
    for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
      output_ptr[i] = 0.0f;
    }
  }

  // backward
  if (gru_parm->bidirectional_) {
    input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
    input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
    input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;

    state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
    state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
    state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;

    float *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
    const float *backward_bias = bias + 3 * gru_parm->hidden_size_;
    float *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
    for (int t = check_seq_len - 1; t >= 0; t--) {
      const float *input_ptr = input + t * gru_parm->input_step_;
      float *output_ptr = backward_output + t * gru_parm->output_step_;
      GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
                  state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, backward_hidden_state,
                  gate_buffer, gru_parm);
    }
    // zero out extra bw outputs
    for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
      float *output_ptr = backward_output + t * gru_parm->output_step_;
      for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
        output_ptr[i] = 0.0f;
      }
    }
  }
}
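For reference (not part of the patch), GruStepUnit above follows the standard GRU cell with linear_before_reset = 0, the variant noted in src/ops/gru.h. A sketch of what it computes, with sigmoid/tanh as in the code, ⊙ for element-wise multiply, W_* the input-weight blocks taken from weight_g, R_* the recurrent blocks from weight_r, and b_* the per-gate biases prepared by the caller:

  r_t  = sigmoid(W_r x_t + R_r h_{t-1} + b_r)
  z_t  = sigmoid(W_z x_t + R_z h_{t-1} + b_z)
  h~_t = tanh(W_h x_t + R_h (r_t ⊙ h_{t-1}) + b_h)
  h_t  = z_t ⊙ h_{t-1} + (1 - z_t) ⊙ h~_t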

+43 -0  mindspore/lite/nnacl/fp32/gru_fp32.h

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
#define MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
#include "nnacl/op_base.h"

typedef struct GruParameter {
  // Primitive parameter
  OpParameter op_parameter_;
  // shape correlative
  int input_size_;
  int hidden_size_;  // output_size
  int seq_len_;
  int batch_;
  // other parameter
  int input_step_;
  int output_step_;
  bool bidirectional_;
} GruParameter;

#ifdef __cplusplus
extern "C" {
#endif
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
         float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm);
#ifdef __cplusplus
}
#endif

#endif // MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_

+60 -8  mindspore/lite/nnacl/fp32/lstm_fp32.c

@@ -16,6 +16,7 @@

 #include "nnacl/fp32/lstm_fp32.h"
 #include <string.h>
+#include <float.h>
 #include "nnacl/fp32/activation_fp32.h"
 #include "nnacl/fp32/arithmetic_fp32.h"

@@ -79,21 +80,63 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int
   }
 }

+int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  for (; index <= element_size - 4; index += C4NUM) {
+    float32x4_t vin0 = vld1q_f32(input0 + index);
+    float32x4_t vout = vld1q_f32(output + index);
+    vout = vmlaq_n_f32(vout, vin0, input1);
+    vst1q_f32(output + index, vout);
+  }
+#endif
+  for (; index < element_size; index++) {
+    output[index] += input0[index] * input1;
+  }
+  return NNACL_OK;
+}
+
 void UpdataState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate,
-                 int batch, int hidden_size) {
+                 float *state_buffer, int batch, int hidden_size, const float smooth) {
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {  // smooth * old_cell_state
+    memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float));
+    ArithmeticParameter parameter;
+    parameter.in_elements_num0_ = batch * hidden_size;
+    parameter.in_elements_num1_ = 1;
+    ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
+  }
+
   ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size);
   ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size);
+
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {  // (1 - smooth) * new_cell_state
+    ElementOptMulAcc(cell_state, 1 - smooth, state_buffer, batch * hidden_size);
+  }
 }

-void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, int batch, int hidden_size) {
+void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, float *state_buffer_in,
+                  int batch, int hidden_size, const float smooth) {
+  float *state_buffer = state_buffer_in + batch * hidden_size;
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
+    memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float));
+    ArithmeticParameter parameter;
+    parameter.in_elements_num0_ = batch * hidden_size;
+    parameter.in_elements_num1_ = 1;
+    ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
+  }
+
   Tanh(cell_state, batch * hidden_size, hidden_state);
   ElementMul(hidden_state, output_gate, hidden_state, batch * hidden_size);
+
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
+    ElementOptMulAcc(hidden_state, 1 - smooth, state_buffer, batch * hidden_size);
+  }
 }

 void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight,
                   const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight,
                   const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight,
-                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer,
+                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
                   const LstmParameter *lstm_parm) {
   InitGate(gate_buffer, bias, lstm_parm);

@@ -129,17 +172,26 @@ void LstmStepUnit(float *output, const float *input, const float *input_input_we
   // update cell_gate
   Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate);
   // update cell state
-  UpdataState(cell_state, forget_gate, input_gate, cell_gate, lstm_parm->batch_, lstm_parm->hidden_size_);
+  UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
+              lstm_parm->smooth_);

   // update output_gate
   Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate);
   // update output
-  UpdataOutput(cell_state, output_gate, hidden_state, lstm_parm->batch_, lstm_parm->hidden_size_);
+  UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
+               lstm_parm->smooth_);
   memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));

+  if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) {
+    memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
+    memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_,
+           lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
+  }
 }

 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
-          float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm) {
+          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
+          const LstmParameter *lstm_parm) {
   // forward
   const float *input_input_weight = weight_i;
   const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
@@ -156,7 +208,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float
     float *output_ptr = output + t * lstm_parm->output_step_;
     LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight,
                  state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state,
-                 cell_state, gate_buffer, lstm_parm);
+                 cell_state, gate_buffer, state_buffer, lstm_parm);
   }

   // backward
@@ -180,7 +232,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float
     float *output_ptr = backward_output + t * lstm_parm->output_step_;
     LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
                  input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight,
-                 backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, lstm_parm);
+                 backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm);
   }
 }
 }
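A minimal illustration of the new smooth path (not part of the patch): when smooth_ is non-zero, Lstm() expects a caller-provided state_buffer of 2 * batch_ * hidden_size_ floats, the first half receiving the blended cell state and the second half the blended hidden state, each blended as old_state * smooth + new_state * (1 - smooth). A hedged C sketch of the sizing logic, mirroring what LstmCPUKernel::InitBuffer does later in this commit; the helper name is hypothetical:

#include <float.h>
#include <stdlib.h>
#include "nnacl/fp32/lstm_fp32.h"

/* Hypothetical helper, for illustration only: allocate the scratch buffer the
 * smoothed LSTM path needs (cell-state block followed by hidden-state block). */
static float *AllocLstmStateBuffer(const LstmParameter *parm) {
  if (parm->smooth_ >= -FLT_EPSILON && parm->smooth_ <= FLT_EPSILON) {
    return NULL; /* smoothing disabled: Lstm() never touches the buffer */
  }
  return (float *)malloc(2 * parm->batch_ * parm->hidden_size_ * sizeof(float));
}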

+12 -1  mindspore/lite/nnacl/fp32/lstm_fp32.h

@@ -31,13 +31,24 @@ typedef struct LstmParameter {
   int input_step_;
   int output_step_;
   bool bidirectional_;
+  // smooth factor for hidden/cell state calculation:
+  // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
+  // output_cell = old_cell * smooth + new_cell * (1 - smooth)
+  float smooth_;
 } LstmParameter;

 #ifdef __cplusplus
 extern "C" {
 #endif
+void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size);
+
+void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size);
+
+int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);
+
 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
-          float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm);
+          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
+          const LstmParameter *lstm_parm);
 #ifdef __cplusplus
 }
 #endif


+1 -0  mindspore/lite/schema/model.fbs

@@ -262,6 +262,7 @@ union PrimitiveType {
     Merge,
     Mod,
     GeLU,
+    Gru,
 }

 enum QuantType: int {


+5 -0  mindspore/lite/schema/ops.fbs

@@ -1005,6 +1005,11 @@ table OneHot {

 table Lstm{
     bidirection: bool = false;
+    smooth: float = 0.0;
+}
+
+table Gru{
+    bidirection: bool = false;
 }

 table PriorBox {


+121 -0  mindspore/lite/src/ops/gru.cc

@@ -0,0 +1,121 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/gru.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool Gru::GetBidirection() const { return this->primitive_->value.AsGru()->bidirection; }

void Gru::SetBidirection(bool bidirection) { this->primitive_->value.AsGru()->bidirection = bidirection; }

#else

bool Gru::GetBidirection() const { return this->primitive_->value_as_Gru()->bidirection(); }
int Gru::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Gru();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Gru return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateGru(*fbb, attr->bidirection());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gru, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}

PrimitiveC *GruCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Gru>(primitive); }
Registry GruRegistry(schema::PrimitiveType_Gru, GruCreator);
#endif

const int kGruInputNum = 5;
const int kGruInputWithSeqLenNum = 6;
const int kGruOutputNum = 2;
int Gru::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
if ((inputs_.size() != kGruInputNum && inputs_.size() != kGruInputWithSeqLenNum) ||
outputs_.size() != kGruOutputNum) {
MS_LOG(ERROR) << "OpGru inputs or outputs size error.";
return RET_INPUT_TENSOR_ERROR;
}
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto weight_gate = inputs_.at(1);
MS_ASSERT(weight_gate != nullptr);
auto weight_recurrence = inputs_.at(2);
MS_ASSERT(weight_recurrence != nullptr);
auto bias = inputs_.at(3);
MS_ASSERT(bias != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
for (int i = 0; i < kGruOutputNum; i++) {
outputs_.at(i)->set_data_type(input->data_type());
outputs_.at(i)->set_format(input->format());
}
if (!infer_flag()) {
return RET_INFER_INVALID;
}

auto in_shape = input->shape(); // seq_len, batch, input_size
auto w_gate_shape = weight_gate->shape(); // num_direction, hidden_size * 3, input_size
auto w_recu_shape = weight_recurrence->shape(); // num_direction, hidden_size * 3, hidden_size
auto bias_shape = bias->shape(); // num_direction, hidden_size * 6
if (in_shape.size() != 3 || w_gate_shape.size() != 3 || w_recu_shape.size() != 3) {
MS_LOG(ERROR) << "OpGru input dims should be 3.";
return RET_ERROR;
}
if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) {
MS_LOG(ERROR) << "OpGru w_gate, w_recu and bias hidden size not match.";
return RET_ERROR;
}
if (inputs_.size() == kGruInputWithSeqLenNum) {
auto seq_len_shape = inputs_.at(5)->shape();
if (seq_len_shape[0] > 1) {
MS_LOG(WARNING) << "OpGru with batch_size > 1 only support all same sequence_len now.";
return RET_ERROR;
}
if (seq_len_shape.size() != 1 && seq_len_shape[0] != in_shape[1]) {
MS_LOG(ERROR) << "OpGru sequence_len shape[0] and batch_size not match.";
return RET_ERROR;
}
}

int hidden_size = w_gate_shape[1] / 3;
// set output
std::vector<int> out_shape(in_shape);
out_shape[2] = hidden_size;
if (GetBidirection()) {
out_shape.insert(out_shape.begin() + 1, 2);
} else {
out_shape.insert(out_shape.begin() + 1, 1);
}
output->set_shape(out_shape);
// set hidden state
std::vector<int> state_shape(in_shape);
state_shape[0] = GetBidirection() ? 2 : 1;
state_shape[2] = hidden_size;
outputs_[1]->set_shape(state_shape);

return RET_OK;
}
} // namespace lite
} // namespace mindspore
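Worked example of the shape inference above (illustrative numbers, not from the patch): with an input of shape {seq_len = 8, batch = 1, input_size = 16}, weight_gate of shape {2, 96, 16} and bidirection = true, hidden_size is 96 / 3 = 32, so the first output gets shape {8, 2, 1, 32} and the output hidden state gets shape {2, 1, 32}.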

+47 -0  mindspore/lite/src/ops/gru.h

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_GRU_H_
#define MINDSPORE_LITE_SRC_OPS_GRU_H_
#include <vector>
#include <set>
#include <cmath>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
/*
 * gru with linear_before_reset = 0
 */
class Gru : public PrimitiveC {
 public:
  Gru() = default;
  ~Gru() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Gru, PrimitiveC);
  explicit Gru(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetBidirection(bool bidirection);

#else
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  bool GetBidirection() const;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_OPS_GRU_H_

+6 -1  mindspore/lite/src/ops/lstm.cc

@@ -25,11 +25,16 @@ namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
 bool Lstm::GetBidirection() const { return this->primitive_->value.AsLstm()->bidirection; }

+float Lstm::GetSmooth() const { return this->primitive_->value.AsLstm()->smooth; }
+
 void Lstm::SetBidirection(bool bidirection) { this->primitive_->value.AsLstm()->bidirection = bidirection; }

+void Lstm::SetSmooth(float smooth) { this->primitive_->value.AsLstm()->smooth = smooth; }
+
 #else

 bool Lstm::GetBidirection() const { return this->primitive_->value_as_Lstm()->bidirection(); }
+float Lstm::GetSmooth() const { return this->primitive_->value_as_Lstm()->smooth(); }
 int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
   MS_ASSERT(nullptr != fbb);
@@ -38,7 +43,7 @@ int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
     MS_LOG(ERROR) << "value_as_Lstm return nullptr";
     return RET_ERROR;
   }
-  auto val_offset = schema::CreateLstm(*fbb, attr->bidirection());
+  auto val_offset = schema::CreateLstm(*fbb, attr->bidirection(), attr->smooth());
   auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Lstm, val_offset.o);
   fbb->Finish(prim_offset);
   return RET_OK;


+2 -0  mindspore/lite/src/ops/lstm.h

@@ -33,12 +33,14 @@ class Lstm : public PrimitiveC {
   MS_DECLARE_PARENT(Lstm, PrimitiveC);
   explicit Lstm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
   void SetBidirection(bool bidirection);
+  void SetSmooth(float smooth);

 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
   int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
   bool GetBidirection() const;
+  float GetSmooth() const;
 };
 }  // namespace lite
 }  // namespace mindspore


+42 -0  mindspore/lite/src/ops/populate/gru_populate.cc

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/gru.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/gru_fp32.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateGruParameter(const mindspore::lite::PrimitiveC *primitive) {
  GruParameter *gru_param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter)));
  if (gru_param == nullptr) {
    MS_LOG(ERROR) << "malloc GruParameter failed.";
    return nullptr;
  }
  memset(gru_param, 0, sizeof(GruParameter));
  gru_param->op_parameter_.type_ = primitive->Type();
  auto param = reinterpret_cast<mindspore::lite::Gru *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  if (param == nullptr) {
    free(gru_param);
    MS_LOG(ERROR) << "get Gru param nullptr.";
    return nullptr;
  }
  gru_param->bidirectional_ = param->GetBidirection();
  return reinterpret_cast<OpParameter *>(gru_param);
}
Registry GruParameterRegistry(schema::PrimitiveType_Gru, PopulateGruParameter);
} // namespace lite
} // namespace mindspore

+1 -0  mindspore/lite/src/ops/populate/lstm_populate.cc

@@ -36,6 +36,7 @@ OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive)
     return nullptr;
   }
   lstm_param->bidirectional_ = param->GetBidirection();
+  lstm_param->smooth_ = param->GetSmooth();
   return reinterpret_cast<OpParameter *>(lstm_param);
 }
 Registry LstmParameterRegistry(schema::PrimitiveType_Lstm, PopulateLstmParameter);


+3 -0  mindspore/lite/src/ops/primitive_c.cc

@@ -161,6 +161,7 @@
 #include "src/ops/switch.h"
 #include "src/ops/partial.h"
 #include "src/ops/gelu.h"
+#include "src/ops/gru.h"

 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"
@@ -995,6 +996,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) AssertOP(primitive);
     case schema::PrimitiveType_GeLU:
       return new (std::nothrow) GeLU(primitive);
+    case schema::PrimitiveType_Gru:
+      return new (std::nothrow) Gru(primitive);
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
       return new (std::nothrow) ActivationGrad(primitive);


+165 -0  mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc

@@ -0,0 +1,165 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/gru_fp32.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gru;

namespace mindspore::kernel {
void GruCPUKernel::FreeTmpBuffer() {
if (gate_buffer_ != nullptr) {
free(gate_buffer_);
gate_buffer_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
weight_g_ptr_ = nullptr;
weight_r_ptr_ = nullptr;
}

int GruCPUKernel::InitParam() {
auto input = in_tensors_.front();
MS_ASSERT(input != nullptr);
std::vector<int> in_shape = input->shape();
gru_parm_->seq_len_ = in_shape.at(0);
gru_parm_->batch_ = in_shape.at(1);
gru_parm_->input_size_ = in_shape.at(2);

auto weight_g = in_tensors_.at(1);
MS_ASSERT(weight_g != nullptr);
std::vector<int> w_shape = weight_g->shape();
gru_parm_->hidden_size_ = w_shape.at(1) / 3;

gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
: gru_parm_->batch_ * gru_parm_->hidden_size_;
return RET_OK;
}

int GruCPUKernel::InitBuffer() {
gate_buffer_ = reinterpret_cast<float *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float)));
if (gate_buffer_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
return RET_ERROR;
}
return RET_OK;
}

int GruCPUKernel::InitWeightBias() {
auto weight_gate = in_tensors_.at(1);
MS_ASSERT(weight_gate != nullptr);
weight_g_ptr_ = reinterpret_cast<float *>(weight_gate->data_c());

auto weight_recu = in_tensors_.at(2);
MS_ASSERT(weight_recu != nullptr);
weight_r_ptr_ = reinterpret_cast<float *>(weight_recu->data_c());

int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
return RET_ERROR;
}

auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
const int state_bias_offset = 3 * gru_parm_->hidden_size_;
for (int i = 0; i < state_bias_offset; i++) {
bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
}
if (gru_parm_->bidirectional_) {
bias_data += 3 * gru_parm_->hidden_size_ * 2;
auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
for (int i = 0; i < state_bias_offset; i++) {
backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
}
}
return RET_OK;
}

int GruCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int GruCPUKernel::ReSize() {
FreeTmpBuffer();
auto ret = InitParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitParam error.";
return RET_ERROR;
}

ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error.";
FreeTmpBuffer();
return RET_ERROR;
}

ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitBuffer error.";
FreeTmpBuffer();
return RET_ERROR;
}
return RET_OK;
}

int GruCPUKernel::Run() {
auto input = in_tensors_.at(kInputIndex);
MS_ASSERT(input != nullptr);
auto hidden_state = in_tensors_.at(4);
MS_ASSERT(hidden_state != nullptr);
auto output = out_tensors_.at(0);
MS_ASSERT(output != nullptr);
auto input_ptr = reinterpret_cast<float *>(input->data_c());
MS_ASSERT(input_ptr);
auto output_ptr = reinterpret_cast<float *>(output->MutableData());
MS_ASSERT(output_ptr);
auto output_hidden_state = out_tensors_[1];
memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
int check_seq_len = gru_parm_->seq_len_;
if (in_tensors_.size() == 6) {
auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
MS_LOG(ERROR) << "different batch seq_len is currently not supported";
return RET_ERROR;
}
check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
}

MS_ASSERT(weight_g_ptr_);
MS_ASSERT(weight_r_ptr_);
MS_ASSERT(bias_ptr_);
MS_ASSERT(gate_buffer_);
Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
reinterpret_cast<float *>(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_);
return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gru, LiteKernelCreator<GruCPUKernel>)
} // namespace mindspore::kernel
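Note on InitWeightBias above: the op's bias input carries num_direction x (6 * hidden_size) values (the gate bias followed by the recurrent bias), and the kernel folds each pair into a single 3 * hidden_size bias per direction before calling Gru(). In the linear_before_reset = 0 formulation both biases are added to the same pre-activation, so summing them up front is equivalent.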

+52 -0  mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp32/gru_fp32.h"

namespace mindspore::kernel {
class GruCPUKernel : public LiteKernel {
 public:
  GruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
               const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
               const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
  }

  ~GruCPUKernel() override { FreeTmpBuffer(); }

  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  void FreeTmpBuffer();
  int InitParam();
  int InitBuffer();
  int InitWeightBias();

  float *gate_buffer_ = nullptr;
  const float *weight_g_ptr_ = nullptr;
  const float *weight_r_ptr_ = nullptr;
  float *bias_ptr_ = nullptr;
  GruParameter *gru_parm_ = nullptr;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_

+14 -1  mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc

@@ -15,6 +15,7 @@
  */

 #include "src/runtime/kernel/arm/fp32/lstm_fp32.h"
+#include <float.h>
 #include <vector>
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
@@ -32,6 +33,10 @@ void LstmCPUKernel::FreeTmpBuffer() {
     free(gate_buffer_);
     gate_buffer_ = nullptr;
   }
+  if (state_buffer_ != nullptr) {
+    free(state_buffer_);
+    state_buffer_ = nullptr;
+  }
   if (weight_i_ptr_ != nullptr) {
     free(weight_i_ptr_);
     weight_i_ptr_ = nullptr;
@@ -71,6 +76,14 @@ int LstmCPUKernel::InitBuffer() {
     MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
     return RET_ERROR;
   }
+  if (!(lstm_parm_->smooth_ >= -FLT_EPSILON && lstm_parm_->smooth_ <= FLT_EPSILON)) {
+    int buffer_size = 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float);
+    state_buffer_ = reinterpret_cast<float *>(malloc(buffer_size));
+    if (state_buffer_ == nullptr) {
+      MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error.";
+      return RET_ERROR;
+    }
+  }
   return RET_OK;
 }

@@ -173,7 +186,7 @@ int LstmCPUKernel::Run() {
   MS_ASSERT(gate_buffer_);
   Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
        reinterpret_cast<float *>(output_hidden_state->MutableData()),
-       reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, lstm_parm_);
+       reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, state_buffer_, lstm_parm_);
   return RET_OK;
 }




+1 -0  mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h

@@ -44,6 +44,7 @@ class LstmCPUKernel : public LiteKernel {
   int InitWeightBias();

   float *gate_buffer_ = nullptr;
+  float *state_buffer_ = nullptr;
   float *weight_i_ptr_ = nullptr;
   float *weight_h_ptr_ = nullptr;
   float *bias_ptr_ = nullptr;


+3 -0  mindspore/lite/test/CMakeLists.txt

@@ -187,6 +187,9 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc


+1 -0  mindspore/lite/test/models_tf.cfg

@@ -0,0 +1 @@
decoder_step_201217.pb 5
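Each line here is a model name followed by its input count; the benchmark loop added below builds one input file per count (for this entry, /data/local/tmp/input_output/input/decoder_step_201217.pb.ms_1.bin through ..._5.bin) and compares the result against /data/local/tmp/input_output/output/decoder_step_201217.pb.ms.out.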

+35 -0  mindspore/lite/test/run_benchmark_nets.sh

@@ -925,6 +925,41 @@ function Run_arm64() {
         fi
     done < ${models_compatibility_config}

+    # Run tf converted models:
+    while read line; do
+        model_name=${line}
+        if [[ $model_name == \#* ]]; then
+            continue
+        fi
+        model_name=`echo ${tf_line_info}|awk -F ' ' '{print $1}'`
+        input_num=`echo ${tf_line_info}|awk -F ' ' '{print $2}'`
+        input_files=''
+        for i in $(seq 1 $input_num)
+        do
+            input_files=$input_files'/data/local/tmp/input_output/input/'$model_name'.ms_'$i'.bin,'
+        done
+        echo ${model_name} >> "${run_arm64_log_file}"
+        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> "${run_arm64_log_file}"
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> adb_run_cmd.txt
+        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
+        if [ $? = 0 ]; then
+            run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+        else
+            run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+        fi
+        # run benchmark test without clib data
+        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> "${run_arm64_log_file}"
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt
+        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
+        if [ $? = 0 ]; then
+            run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+        else
+            run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+        fi
+    done < ${models_tf_config}
+
     # Run tflite converted models:
     while read line; do
         model_name=${line}


+3 -0  mindspore/lite/tools/converter/CMakeLists.txt

@@ -46,6 +46,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/fusion/batchmatmul_fusion.cc
         ../optimizer/fusion/sigmoid_mul_fusion.cc
         ../optimizer/fusion/conv_conv_fusion.cc
+        ../optimizer/fusion/tflite_lstm_cell_fusion.cc
+        ../optimizer/fusion/tf_lstm_cell_fusion.cc
+        ../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
        ../optimizer/graph/weight_format_transform_pass.cc
        ../optimizer/graph/weight_format_hardcode_pass.cc
        ../optimizer/graph/clip_convert_activation_pass.cc


+6 -0  mindspore/lite/tools/converter/anf_transform.cc

@@ -29,6 +29,9 @@
 #include "tools/optimizer/fusion/batchmatmul_fusion.h"
 #include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
 #include "tools/optimizer/fusion/conv_conv_fusion.h"
+#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
+#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
+#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
 #include "tools/optimizer/graph/mindir_adjust_pass.h"
 #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
 #include "tools/optimizer/graph/identity_remove_pass.h"
@@ -114,6 +117,9 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
     fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
     fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
     fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>());
   }
   auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
   weight_format_hardcode_pass->SetFmkType(config->fmk);


+12 -1  mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc

@@ -572,7 +572,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
     if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) {
       type = TypeIdToType(kObjectTypeTensorType);
     }
-    anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector));
+    auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape_vector);
+    if (abstract == nullptr) {
+      MS_LOG(ERROR) << "create AbstractTensor failed";
+      return RET_ERROR;
+    }
+    anf_node->set_abstract(abstract);
     anf_node_map->insert(std::pair(op.name(), anf_node));
   } else {
     AbstractBasePtrList abstractList;
@@ -589,6 +594,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
       std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue};
       CNodePtr getItemCNode = anf_graph->NewCNode(inputs);
       std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
+      auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
+      if (abstract == nullptr) {
+        MS_LOG(ERROR) << "create AbstractTensor failed";
+        return RET_ERROR;
+      }
+      getItemCNode->set_abstract(abstract);
       getItemCNode->set_fullname_with_scope(output_item_name);
       anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode));
     }


+5 -1  mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc

@@ -63,7 +63,11 @@ STATUS TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op,
   }

   *output_size = 1;
-  return AddOpInput(tf_op, 0, inputs);
+  auto status = AddOpInput(tf_op, 0, inputs);
+  if (status != RET_OK) {
+    return status;
+  }
+  return AddOpInput(tf_op, 1, inputs);
 }
 TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser());
 }  // namespace lite


+61 -0  mindspore/lite/tools/converter/parser/tf/tf_select_parser.cc

@@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_select_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFSelectParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF SelectParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::SwitchT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}

primitive->value.type = schema::PrimitiveType_Switch;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return RET_OK;
}
TFNodeRegistrar g_tfSelectParser("Select", new TFSelectParser());
} // namespace lite
} // namespace mindspore

+37 -0  mindspore/lite/tools/converter/parser/tf/tf_select_parser.h

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFSelectParser : public TFNodeParser {
 public:
  TFSelectParser() = default;
  ~TFSelectParser() override = default;

  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_

+1 -1  mindspore/lite/tools/converter/parser/tf/tf_util.cc

@@ -122,7 +122,7 @@ std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) {
                               std::sregex_token_iterator());
   std::string ret = input_name;
   if (input_splits.size() == 3) {
-    if (input_splits[2] == "0") {
+    if (input_splits[2].compare("0") == 0) {
       ret = input_splits[0];
     } else {
       ret = input_splits[0] + ":" + input_splits[2];  // multi output node


+679 -0  mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc

@@ -0,0 +1,679 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
#include <memory>
#include <functional>
#include "src/ops/primitive_c.h"
#include "src/common/utils.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kWhileCommonInputsLength = 2;
constexpr size_t kWhileUniqInputsLength = 6;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 69;
constexpr size_t kBodyCNodesNum = 25;
const auto &p1 = std::placeholders::_1;

bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }

bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
return opt::GetCNodeType(n) == type;
}
return false;
}
} // namespace

BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, bool multigraph)
: PatternProcessPass(name, multigraph) {
/*
* vars for while input
* common:
* 0:const0 1:init_state
* fw_while_inputs:
* 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
* bw_while_inputs:
* 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
*/
for (size_t i = 0; i < kWhileCommonInputsLength; ++i) {
common_vars_.emplace_back(std::make_shared<Var>());
}
for (size_t i = 0; i < kWhileUniqInputsLength; ++i) {
fw_vars_.emplace_back(std::make_shared<Var>());
bw_vars_.emplace_back(std::make_shared<Var>());
}
input_ = std::make_shared<Var>();
input_length_ = std::make_shared<Var>();
transpose_input_ = std::make_shared<Var>();
}

const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
auto const1 = std::make_shared<CondVar>(IsParameterNode);
auto ele_shape = std::make_shared<CondVar>(IsParameterNode);

// forward
auto fw_max1 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_});
auto fw_max2 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, fw_max1});

auto fw_shape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_});
auto fw_stride =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), fw_shape});
auto fw_min =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2});

auto fw_reserve =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape,
fw_stride});
auto fw_from_tensor =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)),
transpose_input_, ele_shape});
auto is_fw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], common_vars_[0], fw_stride, common_vars_[0],
fw_reserve, common_vars_[1], fw_min, fw_from_tensor, input_length_});
fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end());
fw_while.emplace_back(common_vars_[1]);
auto fw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)),
fw_while, std::make_shared<Var>()});
auto fw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)),
fw_get_item, ele_shape});
auto fw_out_trans =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), fw_stack});

// backward
auto bw_reverse_seq = VectorRef(
{std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), input_, input_length_});
auto bw_max1 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_});
auto bw_max2 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, bw_max1});
auto bw_trans =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_reverse_seq});
auto bw_shape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans});
auto bw_stride =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), bw_shape});
auto bw_min =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2});
auto bw_reserve =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape,
bw_stride});
auto bw_from_tensor =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), bw_trans,
ele_shape});
auto is_bw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], common_vars_[0], bw_stride, common_vars_[0],
bw_reserve, common_vars_[1], bw_min, bw_from_tensor, input_length_});
bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end());
bw_while.emplace_back(common_vars_[1]);
auto bw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)),
bw_while, std::make_shared<Var>()});
auto bw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)),
bw_get_item, ele_shape});
auto bw_out_trans =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_stack});
auto bw_reverse1 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), bw_out_trans,
input_length_});

auto concat = VectorRef(
{std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Concat)), fw_out_trans, bw_reverse1});
return concat;
}

AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode);
auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less));
auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less));
auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_LogicalAnd));
auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return));
VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
VectorRef return_ref = VectorRef({is_return, logicaland_ref});
VarPtr fg = std::make_shared<Var>("RootG");
auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true);
return pattern;
}

AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
std::vector<CondVarPtr> placeholders;
for (int i = 0; i < 13; ++i) {
placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
}
VectorRef add = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
VectorRef add1 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)});

VectorRef get_item = VectorRef(
{std::make_shared<Var>("GetItem"), placeholders[6], placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[4]});

VectorRef matmul1 = VectorRef({std::make_shared<Var>("Matmul"), concat_input_h, placeholders[8]});
VectorRef biasadd1 = VectorRef({std::make_shared<Var>("BiasAdd"), matmul1, placeholders[9]});
VectorRef sigmoid1 = VectorRef({std::make_shared<Var>("Sigmoid"), biasadd1});

VectorRef split = VectorRef({std::make_shared<Var>("Split"), sigmoid1});
VectorRef get_item1 = VectorRef({std::make_shared<Var>("TupleGetItem"), split, std::make_shared<Var>()});
VectorRef get_item2 = VectorRef({std::make_shared<Var>("TupleGetItem"), split, std::make_shared<Var>()});

VectorRef pre_reset = VectorRef({std::make_shared<Var>("Mul"), get_item1, placeholders[4]});
VectorRef concat2 = VectorRef({std::make_shared<Var>("Concat"), get_item, pre_reset});
VectorRef matmul2 = VectorRef({std::make_shared<Var>("Matmul"), concat2, placeholders[10]});
VectorRef biasadd2 = VectorRef({std::make_shared<Var>("BiasAdd"), matmul2, placeholders[11]});
VectorRef tanh = VectorRef({std::make_shared<Var>("Tanh"), biasadd2});

VectorRef update_hidden = VectorRef({std::make_shared<Var>("Mul"), get_item2, placeholders[4]});
VectorRef minus_update =
VectorRef({std::make_shared<Var>("Sub"), std::make_shared<CondVar>(IsParameterNode), get_item2});
VectorRef updated = VectorRef({std::make_shared<Var>("Mul"), minus_update, tanh});

VectorRef new_hidden = VectorRef({std::make_shared<Var>("Add"), update_hidden, updated});

VectorRef greater_equal = VectorRef({std::make_shared<Var>("GreaterEqual"), placeholders[2], placeholders[7]});

VectorRef select_output = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[12], new_hidden});
VectorRef output = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], select_output});

VectorRef select_hidden = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[4], new_hidden});

auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple));
std::vector<BaseRef> outputs = {is_make_tuple, add1, placeholders[1], add,
output, select_hidden, placeholders[5], placeholders[6],
placeholders[7]};
outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end());
VectorRef make_tuple_node = VectorRef(outputs);
auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return));
VectorRef return_node = VectorRef({is_return, make_tuple_node});

VarPtr fg = std::make_shared<Var>("RootG");
auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true);
return pattern;
}
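// For reference, the body pattern above roughly corresponds to a standard TF GRU cell,
// with the reset/update gates computed from one fused gate kernel (a sketch, not lifted
// from the source):
//   [r, z]  = sigmoid(W_gate * concat(x_t, h_{t-1}) + b_gate)   // split -> get_item1, get_item2
//   h_cand  = tanh(W_cand * concat(x_t, r * h_{t-1}) + b_cand)
//   h_t     = z * h_{t-1} + (1 - z) * h_cand
// The GreaterEqual/Switch pair keeps the previous hidden state once a batch entry has
// passed its sequence length.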

ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNodePtr &parameter_anf) const {
MS_ASSERT(parameter_anf != nullptr);
if (!utils::isa<ParameterPtr>(parameter_anf)) {
MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr";
return nullptr;
}
auto parameter = utils::cast<ParameterPtr>(parameter_anf);
if (!parameter->has_default()) {
MS_LOG(DEBUG) << "parameter not have default value";
return nullptr;
}
auto param_value = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
return param_value;
}

STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf,
const AnfNodePtr &bw_cand_kernel_anf, int *input_size,
int *hidden_size) const {
MS_ASSERT(fw_cand_kernel_anf != nullptr);
MS_ASSERT(bw_cand_kernel_anf != nullptr);
MS_ASSERT(input_size != nullptr);
MS_ASSERT(hidden_size != nullptr);
auto fw_cand_kernel_value = GetDefaultParamValue(fw_cand_kernel_anf);
if (fw_cand_kernel_value == nullptr) {
return RET_ERROR;
}
auto fw_cand_kernel_shape = fw_cand_kernel_value->tensor_shape();
if (fw_cand_kernel_shape.size() != 2) {
return RET_ERROR;
}
auto bw_cand_kernel_value = GetDefaultParamValue(bw_cand_kernel_anf);
if (bw_cand_kernel_value == nullptr) {
return RET_ERROR;
}
auto bw_cand_kernel_shape = bw_cand_kernel_value->tensor_shape();
if (bw_cand_kernel_shape.size() != 2) {
return RET_ERROR;
}
if (fw_cand_kernel_shape != bw_cand_kernel_shape) {
return RET_ERROR;
}
if (fw_cand_kernel_shape[1] <= 0 || fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1] <= 0) {
MS_LOG(DEBUG) << "gru input size or hidden size illegal";
return RET_ERROR;
}
*hidden_size = fw_cand_kernel_shape[1];
*input_size = fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1];
return RET_OK;
}
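// The candidate kernel of a TF GRUCell is expected to have shape
// [input_size + hidden_size, hidden_size]; e.g. a {96, 64} kernel (illustrative values)
// would yield hidden_size = 64 and input_size = 96 - 64 = 32.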

ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
const std::vector<int> &shape, const TypeId type,
void **tensor_data) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(tensor_data != nullptr);
auto parameter = func_graph->add_parameter();
parameter->set_name(name);
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector);
if (abstract_tensor == nullptr) {
return nullptr;
}
parameter->set_abstract(abstract_tensor);

auto gate_weight_default = std::make_shared<ParamValueLite>();
if (gate_weight_default == nullptr) {
MS_LOG(ERROR) << "gate_weight_default is nullptr";
return nullptr;
}
gate_weight_default->set_tensor_shape(shape);
gate_weight_default->set_tensor_type(type);
gate_weight_default->set_format(schema::Format_NHWC);
int data_len = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int data_size = 0;
if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) {
data_size = data_len * sizeof(float);
*tensor_data = new (std::nothrow) float[data_len];
} else if (type == kNumberTypeInt || type == kNumberTypeInt32) {
data_size = data_len * sizeof(int);
*tensor_data = new (std::nothrow) int[data_len];
} else {
MS_LOG(DEBUG) << "unsupported data type";
return nullptr;
}
if (*tensor_data == nullptr) {
MS_LOG(ERROR) << "new data failed";
return nullptr;
}

gate_weight_default->SetTensorData(*tensor_data, data_size);
parameter->set_default_param(gate_weight_default);
return parameter;
}

void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0,
const int r1, const int c0, const int c1, float *data,
bool t) const {
MS_ASSERT(mat != nullptr);
MS_ASSERT(data != nullptr);
MS_ASSERT(0 <= r0 && r0 < r1 && r1 <= R);
MS_ASSERT(0 <= c0 && c0 < c1 && c1 <= C);
const int RT = r1 - r0;
const int CT = c1 - c0;
for (int i = r0; i < r1; ++i) {
for (int j = c0; j < c1; ++j) {
if (t) {
data[(j - c0) * RT + (i - r0)] = mat[i * C + j];
} else {
data[(i - r0) * CT + (j - c0)] = mat[i * C + j];
}
}
}
}
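// Illustrative example (made-up values): with mat = {1, 2, 3, 4, 5, 6} viewed as a 2x3
// matrix (R = 2, C = 3), copying rows [0, 2) and cols [1, 3) gives
//   t = false : data = {2, 3, 5, 6}   // the 2x2 sub-matrix, row-major
//   t = true  : data = {2, 5, 3, 6}   // the same sub-matrix, transposed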

STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight,
const int input_size, const int hidden_size,
float *gate_tensor_data, float *recu_tensor_data) const {
MS_ASSERT(gate_weight != nullptr);
MS_ASSERT(cand_weight != nullptr);
MS_ASSERT(gate_tensor_data != nullptr);
MS_ASSERT(recu_tensor_data != nullptr);
const std::vector<int> gate_shape{input_size + hidden_size, hidden_size * 2};
const std::vector<int> cand_shape{hidden_size * 2, hidden_size};
auto gate_weight_value = GetDefaultParamValue(gate_weight);
if (gate_weight_value == nullptr) {
return RET_ERROR;
}
auto gate_weight_data = reinterpret_cast<float *>(gate_weight_value->tensor_addr());
if (gate_weight_data == nullptr) {
return RET_ERROR;
}
auto gate_weight_shape = gate_weight_value->tensor_shape();

auto cand_weight_value = GetDefaultParamValue(cand_weight);
if (cand_weight_value == nullptr) {
return RET_ERROR;
}
auto cand_weight_data = reinterpret_cast<float *>(cand_weight_value->tensor_addr());
if (cand_weight_data == nullptr) {
return RET_ERROR;
}
auto cand_weight_shape = cand_weight_value->tensor_shape();

if (gate_weight_shape != gate_shape || cand_weight_shape != cand_shape) {
return RET_ERROR;
}

// input_update_weight
CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, 0, input_size, hidden_size,
hidden_size * 2, gate_tensor_data, true);
// input_reset_weight
CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, 0, input_size, 0, hidden_size,
gate_tensor_data + input_size * hidden_size, true);
// input_hidden_weight
CopyFlattenMatData(cand_weight_data, input_size + hidden_size, hidden_size, 0, input_size, 0, hidden_size,
gate_tensor_data + input_size * hidden_size * 2, true);

// state_update_weight
CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, input_size, input_size + hidden_size,
hidden_size, hidden_size * 2, recu_tensor_data, true);
// state_reset_weight
CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, input_size, input_size + hidden_size,
0, hidden_size, recu_tensor_data + hidden_size * hidden_size, true);
// state_hidden_weight
CopyFlattenMatData(cand_weight_data, input_size + hidden_size, hidden_size, input_size, input_size + hidden_size, 0,
hidden_size, recu_tensor_data + hidden_size * hidden_size * 2, true);
return RET_OK;
}
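// Resulting layout per direction: gate_tensor_data holds the input weights as three
// consecutive hidden_size x input_size blocks [update | reset | candidate] (each block
// transposed while copying), and recu_tensor_data holds the matching hidden_size x
// hidden_size recurrent blocks, which is what the {2, hidden_size * 3, input_size} and
// {2, hidden_size * 3, hidden_size} parameters created in CreateBiDirectionGruNode expect.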

STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias,
const int hidden_size, float *tensor_data) const {
MS_ASSERT(gate_bias != nullptr);
MS_ASSERT(cand_bias != nullptr);
MS_ASSERT(tensor_data != nullptr);
std::vector<int> gate_shape{hidden_size * 2};
std::vector<int> cand_shape{hidden_size};
auto gate_bias_value = GetDefaultParamValue(gate_bias);
if (gate_bias_value == nullptr) {
return RET_ERROR;
}
auto gate_bias_data = reinterpret_cast<float *>(gate_bias_value->tensor_addr());
auto gate_bias_shape = gate_bias_value->tensor_shape();
auto cand_bias_value = GetDefaultParamValue(cand_bias);
if (cand_bias_value == nullptr) {
return RET_ERROR;
}
auto cand_bias_data = reinterpret_cast<float *>(cand_bias_value->tensor_addr());
auto cand_bias_shape = cand_bias_value->tensor_shape();
if (gate_bias_shape != gate_shape || cand_bias_shape != cand_shape) {
return RET_ERROR;
}

// update_gate bias
CopyFlattenMatData(gate_bias_data, 1, hidden_size * 2, 0, 1, hidden_size, hidden_size * 2, tensor_data, false);
// reset_gate bias
CopyFlattenMatData(gate_bias_data, 1, hidden_size * 2, 0, 1, 0, hidden_size, tensor_data + hidden_size, false);
// hidden_gate bias
CopyFlattenMatData(cand_bias_data, 1, hidden_size, 0, 1, 0, hidden_size, tensor_data + hidden_size * 2, false);

return RET_OK;
}
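// Per direction the fused bias spans hidden_size * 6 floats: the first three hidden_size
// blocks are the update / reset / candidate gate biases copied above, while the last
// three blocks are the recurrent biases and stay at the zeros written by the caller.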

CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph,
const AnfNodePtr &hidden_state,
const std::string base_name) const {
MS_ASSERT(func_graph);
MS_ASSERT(hidden_state);
auto stack_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::StackT> attr = std::make_unique<schema::StackT>();
attr->axis = 0;
stack_primitive->value.type = schema::PrimitiveType_Stack;
stack_primitive->value.value = attr.release();
auto stack_cvalue = lite::PrimitiveC::Create(stack_primitive.release());
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(stack_cvalue));
std::vector<AnfNodePtr> new_node_inputs = {value_node, hidden_state, hidden_state};
auto new_node = func_graph->NewCNode(new_node_inputs);
new_node->set_abstract(hidden_state->abstract()->Clone());
new_node->set_fullname_with_scope("stack_hidden_" + base_name);
return new_node;
}

CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const EquivPtr &equiv, const EquivPtr &fw_body_equiv,
const EquivPtr &bw_body_equiv,
const std::string &base_name) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(input != nullptr);
MS_ASSERT(equiv != nullptr);
MS_ASSERT(fw_body_equiv != nullptr);
MS_ASSERT(bw_body_equiv != nullptr);
auto gru_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::GruT> attr = std::make_unique<schema::GruT>();
attr->bidirection = true;
gru_primitive->value.type = schema::PrimitiveType_Gru;
gru_primitive->value.value = attr.release();
auto gru_cvalue = lite::PrimitiveC::Create(gru_primitive.release());
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(gru_cvalue));

auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]);
MS_ASSERT(fw_gate_kernel);
auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]);
MS_ASSERT(fw_gate_bias);
auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]);
MS_ASSERT(fw_cand_kernel);
auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]);
MS_ASSERT(fw_cand_bias);

auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]);
MS_ASSERT(bw_gate_kernel);
auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]);
MS_ASSERT(bw_gate_bias);
auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]);
MS_ASSERT(bw_cand_kernel);
auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]);
MS_ASSERT(bw_cand_bias);

auto hidden = utils::cast<AnfNodePtr>((*equiv)[common_vars_[1]]);
MS_ASSERT(hidden);
auto stacked_hidden = GetStackedHiddenState(func_graph, hidden, base_name);
if (stacked_hidden == nullptr) {
return nullptr;
}
auto input_length = utils::cast<AnfNodePtr>((*equiv)[input_length_]);
MS_ASSERT(input_length);

int input_size = 0;
int hidden_size = 0;
auto status = GetInputAndHiddenSize(fw_cand_kernel, bw_cand_kernel, &input_size, &hidden_size);
if (status != RET_OK) {
return nullptr;
}
std::vector<int> gate_weight_shape{2, hidden_size * 3, input_size};
float *gate_tensor_data = nullptr;
auto gate_weight = AddDefaultParameter(func_graph, base_name + "_gate_weight", gate_weight_shape, kNumberTypeFloat32,
reinterpret_cast<void **>(&gate_tensor_data));
if (gate_weight == nullptr) {
return nullptr;
}
std::vector<int> recu_weight_shape{2, hidden_size * 3, hidden_size};
float *recu_tensor_data = nullptr;
auto recu_weight = AddDefaultParameter(func_graph, base_name + "_cand_weight", recu_weight_shape, kNumberTypeFloat32,
reinterpret_cast<void **>(&recu_tensor_data));
if (recu_weight == nullptr) {
return nullptr;
}
std::vector<int> bias_shape{2, hidden_size * 6};
float *bias_tensor_data = nullptr;
auto bias = AddDefaultParameter(func_graph, base_name + "_bias", bias_shape, kNumberTypeFloat32,
reinterpret_cast<void **>(&bias_tensor_data));
if (bias == nullptr) {
return nullptr;
}
for (int i = 0; i < 2 * hidden_size * 6; ++i) {
bias_tensor_data[i] = 0.0f;
}

if (ConvertWeightData(fw_gate_kernel, fw_cand_kernel, input_size, hidden_size, gate_tensor_data, recu_tensor_data) !=
RET_OK) {
return nullptr;
}
auto gate_data_diff = hidden_size * input_size * 3;
auto recu_data_diff = hidden_size * hidden_size * 3;
if (ConvertWeightData(bw_gate_kernel, bw_cand_kernel, input_size, hidden_size, gate_tensor_data + gate_data_diff,
recu_tensor_data + recu_data_diff) != RET_OK) {
return nullptr;
}

if (ConvertBiasData(fw_gate_bias, fw_cand_bias, hidden_size, bias_tensor_data) != RET_OK) {
return nullptr;
}
auto bias_data_diff = hidden_size * 6;
if (ConvertBiasData(bw_gate_bias, bw_cand_bias, hidden_size, bias_tensor_data + bias_data_diff) != RET_OK) {
return nullptr;
}
std::vector<AnfNodePtr> new_node_inputs = {value_node, input, gate_weight, recu_weight,
bias, stacked_hidden, input_length};
auto new_node = func_graph->NewCNode(new_node_inputs);
new_node->set_fullname_with_scope(base_name);
return new_node;
}
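// The fused Gru cnode therefore takes six tensor inputs, in order: the sequence input,
// the gate (input) weights, the recurrent weights, the bias, the stacked initial hidden
// state (presumably {2, batch, hidden_size} for the two directions), and the per-batch
// sequence lengths.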

CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
const std::string base_name) const {
MS_ASSERT(func_graph);
MS_ASSERT(gru_output);
auto split_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::SplitT> split_attr = std::make_unique<schema::SplitT>();
split_attr->numberSplit = 2;
split_attr->splitDim = 1;
split_primitive->value.type = schema::PrimitiveType_Split;
split_primitive->value.value = split_attr.release();
auto split_cvalue = lite::PrimitiveC::Create(split_primitive.release());
auto split_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(split_cvalue));
std::vector<AnfNodePtr> new_node_inputs = {split_value_node, gru_output};
auto split_new_node = func_graph->NewCNode(new_node_inputs);
split_new_node->set_fullname_with_scope("split_" + base_name);
if (TfliteLstmCellFusion::SetAbstractTuple(split_new_node, 2) != RET_OK) {
return nullptr;
}

auto split_out1 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 0);
if (split_out1 == nullptr) {
return nullptr;
}
auto split_out2 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 1);
if (split_out2 == nullptr) {
return nullptr;
}

auto concat_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::ConcatT> concat_attr = std::make_unique<schema::ConcatT>();
concat_attr->axis = 3;
concat_primitive->value.type = schema::PrimitiveType_Concat;
concat_primitive->value.value = concat_attr.release();
auto concat_cvalue = lite::PrimitiveC::Create(concat_primitive.release());
auto concat_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(concat_cvalue));
std::vector<AnfNodePtr> concat_new_node_inputs = {concat_value_node, split_out1, split_out2};
auto concat_new_node = func_graph->NewCNode(concat_new_node_inputs);
concat_new_node->set_fullname_with_scope("concat_" + base_name);
concat_new_node->set_abstract(gru_output->abstract()->Clone());

auto squeeze_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::SqueezeT> squeeze_attr = std::make_unique<schema::SqueezeT>();
squeeze_attr->axis = std::vector<int>{1};
squeeze_primitive->value.type = schema::PrimitiveType_Squeeze;
squeeze_primitive->value.value = squeeze_attr.release();
auto squeeze_cvalue = lite::PrimitiveC::Create(squeeze_primitive.release());
auto squeeze_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(squeeze_cvalue));
std::vector<AnfNodePtr> squeeze_new_node_inputs = {squeeze_value_node, concat_new_node};
auto squeeze_new_node = func_graph->NewCNode(squeeze_new_node_inputs);
squeeze_new_node->set_fullname_with_scope("squeeze_" + base_name);
squeeze_new_node->set_abstract(gru_output->abstract()->Clone());

auto transpose_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::TransposeT> transpose_attr = std::make_unique<schema::TransposeT>();
transpose_attr->perm = std::vector<int>{1, 0, 2};
transpose_primitive->value.type = schema::PrimitiveType_Transpose;
transpose_primitive->value.value = transpose_attr.release();
auto transpose_cvalue = lite::PrimitiveC::Create(transpose_primitive.release());
auto transpose_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(transpose_cvalue));
std::vector<AnfNodePtr> transpose_new_node_inputs = {transpose_value_node, squeeze_new_node};
auto transpose_new_node = func_graph->NewCNode(transpose_new_node_inputs);
transpose_new_node->set_fullname_with_scope("transpose_" + base_name);
transpose_new_node->set_abstract(gru_output->abstract()->Clone());

return transpose_new_node;
}
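// Assuming the first Gru output is laid out as [seq_len, 2, batch, hidden], the chain
// above restores the layout the original TF graph produced:
//   Split(dim 1)       -> two [seq_len, 1, batch, hidden] tensors
//   Concat(axis 3)     -> [seq_len, 1, batch, 2 * hidden]
//   Squeeze(axis 1)    -> [seq_len, batch, 2 * hidden]
//   Transpose(1, 0, 2) -> [batch, seq_len, 2 * hidden]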

const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph);
MS_ASSERT(concat_node);
MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}

auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
MS_ASSERT(transpose_input);
if (!utils::isa<CNodePtr>(transpose_input) || GetCNodeType(transpose_input) != schema::PrimitiveType_Transpose) {
return nullptr;
}

PrimitiveVarMapPtr fw_cond_primitive_vars = std::make_shared<PrimitiveVarMap>();
auto fw_cond_graph_pattern = GetCondGraphPattern(fw_cond_primitive_vars);
auto fw_cond = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[0]]);
MS_ASSERT(fw_cond != nullptr);
auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_cond_graph_pattern, fw_cond_primitive_vars,
fw_cond, kCondCNodesNum, kCondNodesNum);
if (fw_cond_equiv == nullptr || fw_cond_equiv->empty()) {
return nullptr;
}

PrimitiveVarMapPtr bw_cond_primitive_vars = std::make_shared<PrimitiveVarMap>();
auto bw_cond_graph_pattern = GetCondGraphPattern(bw_cond_primitive_vars);
auto bw_cond = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[0]]);
MS_ASSERT(bw_cond != nullptr);
auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_cond_graph_pattern, bw_cond_primitive_vars,
bw_cond, kCondCNodesNum, kCondNodesNum);
if (bw_cond_equiv == nullptr || bw_cond_equiv->empty()) {
return nullptr;
}

PrimitiveVarMapPtr fw_primitive_vars_body = std::make_shared<PrimitiveVarMap>();
auto fw_body_graph_pattern = GetBodyGraphPattern(fw_primitive_vars_body);
auto fw_body = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[1]]);
MS_ASSERT(fw_body != nullptr);
auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_body_graph_pattern, fw_primitive_vars_body,
fw_body, kBodyCNodesNum, kBodyNodesNum);
if (fw_body_equiv == nullptr || fw_body_equiv->empty()) {
return nullptr;
}

PrimitiveVarMapPtr bw_primitive_vars_body = std::make_shared<PrimitiveVarMap>();
auto bw_body_graph_pattern = GetBodyGraphPattern(bw_primitive_vars_body);
auto bw_body = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[1]]);
MS_ASSERT(bw_body != nullptr);
auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_body_graph_pattern, bw_primitive_vars_body,
bw_body, kBodyCNodesNum, kBodyNodesNum);
if (bw_body_equiv == nullptr || bw_body_equiv->empty()) {
return nullptr;
}

const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, fw_body_equiv, bw_body_equiv, gru_name);
if (gru_node == nullptr) {
return nullptr;
}
if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
return nullptr;
}

auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
if (get_item_node == nullptr) {
return nullptr;
}

auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
return output_node;
}
} // namespace opt
} // namespace mindspore

+ 73
- 0
mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h View File

@@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "include/errorcode.h"

namespace mindspore {
namespace opt {
class BiDirectionTfGruCellFusion : public PatternProcessPass {
public:
explicit BiDirectionTfGruCellFusion(const std::string &name = "bidirection_tf_gru_cell_fusion",
bool multigraph = true);
~BiDirectionTfGruCellFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

protected:
virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;

private:
AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
const EquivPtr &fw_body_equiv, const EquivPtr &bw_body_equiv,
const std::string &base_name) const;
ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr &parameter_anf) const;
lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf,
int *input_size, int *hidden_size) const;
ParameterPtr AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
const std::vector<int> &shape, const TypeId type, void **tensor_data) const;
lite::STATUS ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, const int input_size,
const int hidden_size, float *gate_tensor_data, float *recu_tensor_data) const;
lite::STATUS ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, const int hidden_size,
float *tensor_data) const;
void CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, const int c0,
const int c1, float *data, bool t = false) const;
CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &hidden_state,
const std::string base_name) const;
CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
const std::string base_name) const;

private:
std::vector<VarPtr> common_vars_;
std::vector<VarPtr> fw_vars_;
std::vector<VarPtr> bw_vars_;
VarPtr input_;
VarPtr input_length_;
VarPtr transpose_input_;
};
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_

+ 370
- 0
mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc View File

@@ -0,0 +1,370 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/common/utils.h"
#include "src/param_value_lite.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kLstmInputsLength = 13;
constexpr size_t kLstmInputsVarNum = 11;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 82;
constexpr size_t kBodyCNodesNum = 30;
const auto &p1 = std::placeholders::_1;

bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }

bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
return opt::GetCNodeType(n) == type;
}
return false;
}
} // namespace

TfLstmCellFusion::TfLstmCellFusion(const std::string &name, bool multigraph)
: TfliteLstmCellFusion(name, multigraph, kLstmInputsLength, kLstmInputsVarNum, kCondNodesNum, kCondCNodesNum,
kBodyNodesNum, kBodyCNodesNum) {
/*
* vars for lstm cell input
* 0:cond 1:body 2:index 3:limit1 4:output 5:cell 6:hidden 7:limit2 8:input 9:kernel 10:bias
*/
forget_bias_ = std::make_shared<Var>();
}

AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
std::vector<CondVarPtr> placeholders;
for (int i = 0; i < 10; ++i) {
placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
}
VectorRef add2 = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
VectorRef add3 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)});

VectorRef get_item = VectorRef(
{std::make_shared<Var>("GetItem"), placeholders[7], placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[5]});

VectorRef matmul = VectorRef({std::make_shared<Var>(), concat_input_h, placeholders[8]});
VectorRef bias = VectorRef({std::make_shared<Var>(), matmul, placeholders[9]});
VectorRef split = VectorRef({std::make_shared<Var>(), bias});

VectorRef get_item1 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()});
VectorRef get_item2 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()});
VectorRef get_item3 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()});
VectorRef get_item4 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()});

VectorRef input_gate = VectorRef({std::make_shared<Var>("Sigmoid"), get_item1});
VectorRef input_to_cell = VectorRef({std::make_shared<Var>("Tanh"), get_item2});
VectorRef forget_bias = VectorRef({std::make_shared<Var>("Add"), get_item3, forget_bias_});
VectorRef forget_gate = VectorRef({std::make_shared<Var>("Sigmoid"), forget_bias});
VectorRef output_gate = VectorRef({std::make_shared<Var>("Sigmoid"), get_item4});

VectorRef forgetted_cell = VectorRef({std::make_shared<Var>(""), forget_gate, placeholders[4]});
VectorRef inputted_cell = VectorRef({std::make_shared<Var>(""), input_gate, input_to_cell});
VectorRef input_forget_cell = VectorRef({std::make_shared<Var>("Add"), forgetted_cell, inputted_cell});
VectorRef to_new_hidden = VectorRef({std::make_shared<Var>("Tanh"), input_forget_cell});
VectorRef new_hidden = VectorRef({std::make_shared<Var>("Mul"), output_gate, to_new_hidden});

VectorRef new_to_cell = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_new_, input_forget_cell});
VectorRef old_to_cell = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_old_, placeholders[4]});
VectorRef output_cell = VectorRef({std::make_shared<Var>("Add"), new_to_cell, old_to_cell});

VectorRef new_to_hidden = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_new_, new_hidden});
VectorRef old_to_hidden = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_old_, placeholders[5]});
VectorRef output_hidden = VectorRef({std::make_shared<Var>("Add"), new_to_hidden, old_to_hidden});

VectorRef set_item = VectorRef({std::make_shared<Var>(""), placeholders[3], placeholders[2], new_hidden});

auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple));
std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden};
outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
VectorRef make_tuple_node = VectorRef(outputs);
auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return));
VectorRef return_node = VectorRef({is_return, make_tuple_node});

VarPtr fg = std::make_shared<Var>("RootG");
auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true);
return pattern;
}
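// Compared with the TFLite pattern in the base class, TF's BasicLSTMCell body uses one
// fused kernel: split(W * concat(x_t, h_{t-1}) + b) yields the i / c~ / f / o
// pre-activations and the forget branch adds the scalar forget_bias_ before its sigmoid
// (a sketch, not lifted from the source):
//   c_t = sigmoid(f + forget_bias) * c_{t-1} + sigmoid(i) * tanh(c~)
//   h_t = sigmoid(o) * tanh(c_t)
// The extra Mul/Add pairs on output_cell / output_hidden reuse the smoothing vars
// declared in the base class.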

STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector<int> &shape,
const float *const data_ptr, const int hidden_size) const {
MS_ASSERT(weight != nullptr);
MS_ASSERT(data_ptr != nullptr);
auto default_param = std::make_shared<ParamValueLite>();
if (default_param == nullptr) {
MS_LOG(ERROR) << "new_default is nullptr";
return RET_ERROR;
}
default_param->set_tensor_shape(shape);
default_param->set_tensor_type(kNumberTypeFloat32);
default_param->set_format(schema::Format_NHWC);

if (shape.size() != 3) {
MS_LOG(ERROR) << "lstm weight shape must have 3 dims";
return RET_ERROR;
}
const auto param_num = shape[0] * shape[1] * shape[2];
auto tensor_data = new (std::nothrow) float[param_num * 4];
std::vector<int> data_diff{0, 3, 2, 1};
if (tensor_data == nullptr) {
MS_LOG(DEBUG) << "new data failed";
return RET_ERROR;
}
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < hidden_size; ++j) {
for (int t = 0; t < shape[2]; ++t) {
tensor_data[(i * hidden_size + j) * shape[2] + t] = data_ptr[t * shape[1] + data_diff[i] * hidden_size + j];
}
}
}
default_param->SetTensorData(tensor_data, param_num * 4);
weight->set_default_param(default_param);
std::vector<int64_t> shape_vector_i(shape.begin(), shape.end());
auto abstract_tensor_i = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector_i);
if (abstract_tensor_i == nullptr) {
MS_LOG(ERROR) << "abstract_tensor is nullptr";
delete[] tensor_data;
return RET_ERROR;
}
weight->set_abstract(abstract_tensor_i);
return RET_OK;
}
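// data_diff = {0, 3, 2, 1} re-orders the four hidden_size-wide blocks from TF's
// i, c, f, o layout into the i, o, f, c layout used by the fused Lstm weights (see the
// comment in SplitWeights below): output block 0 <- input block 0 (i), 1 <- 3 (o),
// 2 <- 2 (f), 3 <- 1 (c), while the matrix as a whole is transposed from
// [in_dim, 4 * hidden_size] to [4 * hidden_size, in_dim].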

STATUS TfLstmCellFusion::SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i,
const ParameterPtr &weight_c, int hidden_size) const {
// split input_size and hidden_size at dim 0
// transform i,c,f,o to i,o,f,c at dim 1
MS_ASSERT(weight != nullptr);
MS_ASSERT(weight_i != nullptr);
MS_ASSERT(weight_c != nullptr);
if (!utils::isa<ParameterPtr>(weight)) {
return RET_ERROR;
}
auto weight_param = utils::cast<ParameterPtr>(weight);
if (!weight_param->has_default()) {
MS_LOG(DEBUG) << "weight not have default value";
return RET_ERROR;
}
if (!utils::isa<ParamValueLitePtr>(weight_param->default_param())) {
MS_LOG(DEBUG) << "default value is not ParamValueLite";
return RET_FAILED;
}
auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(weight_param->default_param());
if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) {
MS_LOG(DEBUG) << "origin_tensor is not float32 type";
return RET_ERROR;
}
auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr());
auto data_shape = origin_tensor->tensor_shape();
if (data_shape.size() != 2) {
MS_LOG(ERROR) << "weight data shape invalid";
return RET_ERROR;
}
if (data_shape[0] <= hidden_size) {
MS_LOG(ERROR) << "weight data shape[0] invalid";
return RET_ERROR;
}
if (hidden_size * 4 != data_shape[1]) {
MS_LOG(ERROR) << "weight data shape[1] invalid";
return RET_ERROR;
}
const auto input_size = data_shape[0] - hidden_size;

std::vector<int> shape_i{1, 4 * hidden_size, input_size};
if (SetWeightAbstractAndDefault(weight_i, shape_i, data_ptr, hidden_size) != RET_OK) {
MS_LOG(ERROR) << "get weight_i failed";
return RET_ERROR;
}

std::vector<int> shape_c{1, 4 * hidden_size, hidden_size};
if (SetWeightAbstractAndDefault(weight_c, shape_c, data_ptr + input_size * data_shape[1], hidden_size) != RET_OK) {
MS_LOG(ERROR) << "get weight_i failed";
return RET_ERROR;
}
return RET_OK;
}

STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias,
const AnfNodePtr &old_bias, const int hidden_size) const {
MS_ASSERT(body_equiv != nullptr);
MS_ASSERT(new_bias != nullptr);
MS_ASSERT(old_bias != nullptr);
if (!utils::isa<ParameterPtr>(old_bias)) {
MS_LOG(DEBUG) << "old_bias is not parameter";
return RET_ERROR;
}
auto old_bias_param = utils::cast<ParameterPtr>(old_bias);
if (!old_bias_param->has_default()) {
MS_LOG(DEBUG) << "bias not have default value";
return RET_ERROR;
}
if (!utils::isa<ParamValueLitePtr>(old_bias_param->default_param())) {
MS_LOG(DEBUG) << "default value is not ParamValueLite";
return RET_FAILED;
}
auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(old_bias_param->default_param());
if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) {
MS_LOG(DEBUG) << "origin_tensor is not float32 type";
return RET_ERROR;
}
auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr());
auto data_shape = origin_tensor->tensor_shape();
if (data_shape.size() != 1 || data_shape[0] != 4 * hidden_size) {
MS_LOG(DEBUG) << "bias data shape illegal";
return RET_ERROR;
}
std::vector<int> shape{1, 8 * hidden_size};

auto default_param = std::make_shared<ParamValueLite>();
if (default_param == nullptr) {
MS_LOG(ERROR) << "new_default is nullptr";
return RET_ERROR;
}
default_param->set_tensor_shape(shape);
default_param->set_tensor_type(kNumberTypeFloat32);
default_param->set_format(schema::Format_NHWC);
auto tensor_data = new (std::nothrow) float[hidden_size * 8];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new data failed";
return RET_ERROR;
}

auto forget_bias_node = utils::cast<AnfNodePtr>((*body_equiv)[forget_bias_]);
if (forget_bias_node == nullptr) {
MS_LOG(ERROR) << "forget bias node is nullptr";
delete[] tensor_data;
return RET_ERROR;
}
float forget_bias_value = 0.0f;
if (GetFloatScalarFromParamValueLite(forget_bias_node, &forget_bias_value) != RET_OK) {
delete[] tensor_data;
return RET_ERROR;
}

std::vector<int> data_diff{0, 3, 2, 1};
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < hidden_size; ++j) {
if (i < 4) {
tensor_data[i * hidden_size + j] = data_ptr[data_diff[i] * hidden_size + j];
if (i == 2) { // forget bias
tensor_data[i * hidden_size + j] += forget_bias_value;
}
} else {
tensor_data[i * hidden_size + j] = 0.0f;
}
}
}
default_param->SetTensorData(tensor_data, hidden_size * 8 * 4);
new_bias->set_default_param(default_param);
std::vector<int64_t> shape_vector_i(shape.begin(), shape.end());
auto abstract_tensor_i = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector_i);
if (abstract_tensor_i == nullptr) {
MS_LOG(ERROR) << "abstract_tensor is nullptr";
delete[] tensor_data;
return RET_ERROR;
}
new_bias->set_abstract(abstract_tensor_i);
return RET_OK;
}
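// The populated bias holds 8 * hidden_size floats: the first four hidden_size blocks are
// the TF gate biases re-ordered to i, o, f, c with forget_bias folded into the f block,
// and the last four blocks (the recurrent gate biases) are simply zero.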

CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const EquivPtr &body_equiv, const std::string &base_name,
const float smooth) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
auto lstm_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>();
attr->bidirection = false;
attr->smooth = smooth;
lstm_primitive->value.type = schema::PrimitiveType_Lstm;
lstm_primitive->value.value = attr.release();
auto lstm_cvalue = lite::PrimitiveC::Create(lstm_primitive.release());
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(lstm_cvalue));

auto &vars = while_input_vars_;

auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]);
MS_ASSERT(limit1);
auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]);
MS_ASSERT(limit2);
auto weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
MS_ASSERT(weight);
auto bias = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
MS_ASSERT(bias);
auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]);
MS_ASSERT(input);
auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]);
MS_ASSERT(cell);
auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]);
MS_ASSERT(hidden);

if (!utils::isa<ParameterPtr>(hidden)) {
MS_LOG(DEBUG) << "hidden is not parameter";
return nullptr;
}
auto hidden_param = utils::cast<ParameterPtr>(hidden);
if (!utils::isa<abstract::AbstractTensorPtr>(hidden_param->abstract())) {
MS_LOG(DEBUG) << "hidden abstract is not AbstractTensor";
return nullptr;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract());
auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
if (hidden_shape.size() == 0) {
MS_LOG(DEBUG) << "can't get hidden shape";
return nullptr;
}

auto i_weight = func_graph->add_parameter();
i_weight->set_name(base_name + "_weight_i");
i_weight->set_abstract(weight->abstract()->Clone());

auto c_weight = func_graph->add_parameter();
c_weight->set_name(base_name + "_weight_c");
c_weight->set_abstract(weight->abstract()->Clone());

if (SplitWeights(weight, i_weight, c_weight, hidden_shape.back()) != RET_OK) {
MS_LOG(DEBUG) << "split weight to i_weight and c_weight failed";
return nullptr;
}

auto bias_node = func_graph->add_parameter();
bias_node->set_name(base_name + "_bias");
bias_node->set_abstract(bias->abstract()->Clone());

if (PopulateBiasNode(body_equiv, bias_node, bias, hidden_shape.back()) != RET_OK) {
MS_LOG(DEBUG) << "reorder bias failed";
return nullptr;
}

if (!utils::isa<CNodePtr>(input) || GetCNodeType(input) != schema::PrimitiveType_TensorListFromTensor) {
MS_LOG(DEBUG) << "input is not tensorlistfromtensor op";
return nullptr;
}
auto tensor_list_cnode = utils::cast<CNodePtr>(input);
auto input_tensor_node = tensor_list_cnode->input(1);

std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias_node, hidden,
cell};
auto new_node = func_graph->NewCNode(new_node_inputs);
new_node->set_fullname_with_scope(base_name);
return new_node;
}
} // namespace opt
} // namespace mindspore

+ 53
- 0
mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h View File

@@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "src/param_value_lite.h"
#include "include/errorcode.h"

namespace mindspore {
namespace opt {
class TfLstmCellFusion : public TfliteLstmCellFusion {
public:
explicit TfLstmCellFusion(const std::string &name = "lstm_cell_fusion", bool multigraph = true);
~TfLstmCellFusion() override = default;

private:
AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const override;
CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv,
const std::string &base_name, const float smooth) const override;

lite::STATUS SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i, const ParameterPtr &weight_c,
int hidden_size) const;
lite::STATUS SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector<int> &shape,
const float *const data_ptr, const int hidden_size) const;
lite::STATUS PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias, const AnfNodePtr &old_bias,
const int hidden_size) const;

private:
VarPtr forget_bias_ = nullptr;
};
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_

+ 727
- 0
mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc View File

@@ -0,0 +1,727 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include <memory>
#include <functional>
#include "src/ops/primitive_c.h"
#include "src/common/utils.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kWhileInputsLength = 23;
constexpr size_t kWhileInputsVarNum = 21;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 95;
constexpr size_t kBodyCNodesNum = 34;
constexpr size_t kLSTMOutputNum = 3;
const auto &p1 = std::placeholders::_1;
constexpr float EPSILON = 1e-5;

bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }

bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
return opt::GetCNodeType(n) == type;
}
return false;
}
} // namespace

STATUS TfliteLstmCellFusion::GetFloatScalarFromParamValueLite(const AnfNodePtr &param_value, float *v) const {
if (param_value == nullptr || v == nullptr) {
MS_LOG(ERROR) << "param_value or v is nullptr";
return RET_ERROR;
}
if (!utils::isa<ParameterPtr>(param_value)) {
MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr";
return RET_ERROR;
}
auto param_ptr = utils::cast<ParameterPtr>(param_value);
if (!param_ptr->has_default()) {
MS_LOG(DEBUG) << "param not have default";
return RET_ERROR;
}
auto default_param = param_ptr->default_param();
if (!utils::isa<ParamValueLitePtr>(default_param)) {
MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr";
return RET_ERROR;
}
auto default_param_ptr = utils::cast<ParamValueLitePtr>(default_param);
auto tensor_shape = default_param_ptr->tensor_shape();
if (!(tensor_shape.size() == 0 || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) {
MS_LOG(DEBUG) << "default param is not scalar";
return RET_ERROR;
}
if (default_param_ptr->tensor_type() != kNumberTypeFloat32 && default_param_ptr->tensor_type() != kNumberTypeFloat) {
MS_LOG(DEBUG) << "default param is not float";
return RET_ERROR;
}
*v = *(reinterpret_cast<float *>(default_param_ptr->tensor_addr()));
return RET_OK;
}

TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num,
int cond_nodes_num, int cond_cnodes_num, int body_nodes_num,
int body_cnodes_num)
: PatternProcessPass(name, multigraph) {
/*
* input vars for lstm while node
* 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
* 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
*/
this->while_inputs_num_ = input_length == 0 ? kWhileInputsLength : input_length;
this->while_input_var_num_ = var_num == 0 ? kWhileInputsVarNum : var_num;
this->cond_nodes_num_ = cond_nodes_num == 0 ? kCondNodesNum : cond_nodes_num;
this->cond_cnodes_num_ = cond_cnodes_num == 0 ? kCondCNodesNum : cond_cnodes_num;
this->body_nodes_num_ = body_nodes_num == 0 ? kBodyNodesNum : body_nodes_num;
this->body_cnodes_num_ = body_cnodes_num == 0 ? kBodyCNodesNum : body_cnodes_num;
for (size_t i = 0; i < this->while_input_var_num_; ++i) {
while_input_vars_.emplace_back(std::make_shared<Var>());
}
cell_smooth_old_ = std::make_shared<Var>();
cell_smooth_new_ = std::make_shared<Var>();
hidden_smooth_old_ = std::make_shared<Var>();
hidden_smooth_new_ = std::make_shared<Var>();
}

AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode);
auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less));
auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less));
auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_LogicalAnd));
auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return));
VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
VectorRef return_ref = VectorRef({is_return, logicaland_ref});
VarPtr fg = std::make_shared<Var>("RootG");
auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true);
return pattern;
}

AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
std::vector<CondVarPtr> placeholders;
for (int i = 0; i < 20; ++i) {
placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
}
VectorRef add2 = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
VectorRef add3 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)});

VectorRef concat_i_w = VectorRef({std::make_shared<Var>(), placeholders[8], placeholders[12]});
VectorRef concat_f_w = VectorRef({std::make_shared<Var>(), placeholders[9], placeholders[13]});
VectorRef concat_c_w = VectorRef({std::make_shared<Var>(), placeholders[10], placeholders[14]});
VectorRef concat_o_w = VectorRef({std::make_shared<Var>(), placeholders[11], placeholders[15]});

VectorRef get_item = VectorRef(
{std::make_shared<Var>("GetItem"), placeholders[7], placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[5]});

VectorRef matmul_input = VectorRef({std::make_shared<Var>(), concat_input_h, concat_i_w});
VectorRef matmul_forget = VectorRef({std::make_shared<Var>(), concat_input_h, concat_f_w});
VectorRef matmul_cell = VectorRef({std::make_shared<Var>(), concat_input_h, concat_c_w});
VectorRef matmul_output = VectorRef({std::make_shared<Var>(), concat_input_h, concat_o_w});

VectorRef bias_input = VectorRef({std::make_shared<Var>(), matmul_input, placeholders[16]});
VectorRef bias_forget = VectorRef({std::make_shared<Var>(), matmul_forget, placeholders[17]});
VectorRef bias_cell = VectorRef({std::make_shared<Var>(), matmul_cell, placeholders[18]});
VectorRef bias_output = VectorRef({std::make_shared<Var>(), matmul_output, placeholders[19]});

VectorRef cell = VectorRef({std::make_shared<Var>("Tanh"), bias_cell});
VectorRef input_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_input});
VectorRef cell_input = VectorRef({std::make_shared<Var>("Mul"), input_gate, cell});
VectorRef forget_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_forget});
VectorRef cell_forgeted = VectorRef({std::make_shared<Var>("Mul"), forget_gate, placeholders[4]});
VectorRef cell_new = VectorRef({std::make_shared<Var>("Add"), cell_forgeted, cell_input});

VectorRef smooth_cell_old = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_old_, placeholders[4]});
VectorRef smooth_cell_new = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_new_, cell_new});
VectorRef cell_output = VectorRef({std::make_shared<Var>("Add"), smooth_cell_new, smooth_cell_old});

VectorRef output_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_output});
VectorRef cell_to_output = VectorRef({std::make_shared<Var>("Tanh"), cell_new});
VectorRef output = VectorRef({std::make_shared<Var>("Mul"), output_gate, cell_to_output});

VectorRef smooth_hidden_old = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_old_, placeholders[5]});
VectorRef smooth_hidden_new = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_new_, output});
VectorRef hidden_output = VectorRef({std::make_shared<Var>("Add"), smooth_hidden_new, smooth_hidden_old});

VectorRef set_item = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], output});

auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple));
std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
VectorRef make_tuple_node = VectorRef(outputs);
auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return));
VectorRef return_node = VectorRef({is_return, make_tuple_node});

VarPtr fg = std::make_shared<Var>("RootG");
auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true);
return pattern;
}

const BaseRef TfliteLstmCellFusion::DefinePattern() const {
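// Top-level anchor of the fusion: TensorListStack(TupleGetItem(While(...)), param), i.e. the
// stacked per-step outputs of the while loop that implements the unrolled LSTM.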
auto is_while_node = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
VectorRef while_node = VectorRef({is_while_node});
auto while_inputs = while_input_vars_;
while_inputs.insert(while_inputs.begin() + 4, while_input_vars_[2]);
while_node.insert(while_node.end(), while_inputs.begin(), while_inputs.end());

auto is_tuple_get_item = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem));
VectorRef while_output = VectorRef({is_tuple_get_item, while_node, std::make_shared<Var>()});

auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack));
auto is_parameter = std::make_shared<CondVar>(IsParameterNode);
VectorRef tensor_list_stack_node = VectorRef({is_tensor_list_stack, while_output, is_parameter});

return tensor_list_stack_node;
}

EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
const AnfNodePtr &pattern) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(pattern != nullptr);
auto return_node = func_graph->get_return();
PatternEngine pattern_engine(std::make_shared<DefaultVisitor>(),
                             std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
                             std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual));
auto empty_equiv = std::make_shared<Equiv>();
EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv);
return equiv;
}

// make sure that only outputs 3, 4 and 5 of the while node are referenced
bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(while_cnode != nullptr);
auto manager = func_graph->manager();
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr";
return false;
}
auto while_node_users = manager->node_users()[while_cnode];
std::vector<size_t> valid_indexes{3, 4, 5};
for (auto &node_user : while_node_users) {
if (!utils::isa<CNodePtr>(node_user.first)) {
return false;
}
auto cnode = utils::cast<CNodePtr>(node_user.first);
if (GetCNodeType(cnode) != schema::PrimitiveType_TupleGetItem) {
return false;
}
auto index = GetTupleGetItemOutIndex(cnode);
if (!lite::IsContain(valid_indexes, index)) {
return false;
}
}
return true;
}

EquivPtr TfliteLstmCellFusion::CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern,
const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph,
const size_t cnode_num, const size_t all_node_num) {
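// A sub graph (cond or body) is accepted only if it has exactly the expected cnode/node counts;
// the pattern match is then run against its return node.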
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(pattern != nullptr);
MS_ASSERT(anf_sub_graph != nullptr);
auto sub_graph = GetValueNode<FuncGraphPtr>(anf_sub_graph);
if (sub_graph == nullptr) {
MS_LOG(DEBUG) << "anf_sub_graph is not a FuncGraph value node";
return nullptr;
}
auto nodes = TopoSort(sub_graph->get_return());
auto cnodes = sub_graph->GetOrderedCnodes();
if (cnodes.size() != cnode_num || nodes.size() != all_node_num) {
MS_LOG(DEBUG) << "sub graph node count does not match";
return nullptr;
}
return MatchGraph(sub_graph, primitive_vars, pattern);
}

bool TfliteLstmCellFusion::CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const CNodePtr &while_cnode, float *smooth) const {
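// The four smoothing coefficients captured by the body pattern must be constant scalars in
// [0, 1], old + new must sum to 1 for both the cell and the hidden state, and cell and hidden
// must share the same "old" coefficient, which is returned through *smooth.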
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
MS_ASSERT(while_cnode != nullptr);
MS_ASSERT(smooth != nullptr);

auto cell_smooth_old_node = utils::cast<AnfNodePtr>((*equiv)[cell_smooth_old_]);
MS_ASSERT(cell_smooth_old_node != nullptr);
auto cell_smooth_new_node = utils::cast<AnfNodePtr>((*equiv)[cell_smooth_new_]);
MS_ASSERT(cell_smooth_new_node != nullptr);
auto hidden_smooth_old_node = utils::cast<AnfNodePtr>((*equiv)[hidden_smooth_old_]);
MS_ASSERT(hidden_smooth_old_node != nullptr);
auto hidden_smooth_new_node = utils::cast<AnfNodePtr>((*equiv)[hidden_smooth_new_]);
MS_ASSERT(hidden_smooth_new_node != nullptr);

float cell_old, cell_new, hidden_old, hidden_new;
if (GetFloatScalarFromParamValueLite(cell_smooth_old_node, &cell_old) != RET_OK) {
return false;
}
if (GetFloatScalarFromParamValueLite(cell_smooth_new_node, &cell_new) != RET_OK) {
return false;
}
if (GetFloatScalarFromParamValueLite(hidden_smooth_old_node, &hidden_old) != RET_OK) {
return false;
}
if (GetFloatScalarFromParamValueLite(hidden_smooth_new_node, &hidden_new) != RET_OK) {
return false;
}
if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) {
MS_LOG(DEBUG) << "cell smooth value illegal";
return false;
}
if (hidden_old < 0.0f || hidden_old > 1.0f || hidden_new < 0.0f || hidden_new > 1.0f) {
MS_LOG(DEBUG) << "hidden smooth value illegal";
return false;
}
if (std::abs(cell_old + cell_new - 1.0f) > EPSILON || std::abs(hidden_old + hidden_new - 1.0f) > EPSILON ||
std::abs(cell_old - hidden_old) > EPSILON) {
MS_LOG(DEBUG) << "smooth value illegal";
return false;
}
*smooth = cell_old;
return true;
}

STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &params, const ParameterPtr &new_param,
bool is_bias) const {
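// Concatenates four per-gate parameters (passed by the callers in i, o, f, c order) into one new
// Parameter: weights become a {1, 4 * rows, cols} tensor, biases a {1, 8 * hidden_size} tensor
// whose second half is left zero-filled.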
MS_ASSERT(new_param != nullptr);
MS_ASSERT(params.size() == 4);
std::vector<float *> data_ptrs;
std::vector<std::vector<int>> data_shapes;
for (auto &param : params) {
if (!utils::isa<ParameterPtr>(param)) {
MS_LOG(DEBUG) << "param is not Parameter node";
return RET_FAILED;
}
auto param_t = utils::cast<ParameterPtr>(param);
if (!param_t->has_default()) {
MS_LOG(DEBUG) << "param does not have a default value";
return RET_FAILED;
}
if (!utils::isa<ParamValueLitePtr>(param_t->default_param())) {
MS_LOG(DEBUG) << "default value is not ParamValueLite";
return RET_FAILED;
}
auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(param_t->default_param());
if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) {
MS_LOG(DEBUG) << "origin_tensor is not float32 type";
return RET_FAILED;
}
auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr());
auto data_shape = origin_tensor->tensor_shape();
data_ptrs.push_back(data_ptr);
data_shapes.push_back(data_shape);
}

for (size_t i = 1; i < data_shapes.size(); ++i) {
if (data_shapes[i] != data_shapes[0]) {
MS_LOG(DEBUG) << "data shapes are not the same";
return RET_FAILED;
}
}
auto new_default = std::make_shared<ParamValueLite>();
if (new_default == nullptr) {
MS_LOG(ERROR) << "new_default is nullptr";
return RET_ERROR;
}
std::vector<int> new_shape;
float *tensor_data = nullptr;
int step = 0;
int data_size = 0;
if (is_bias) {
if (data_shapes[0].size() != 1) {
MS_LOG(ERROR) << "bias data shape error";
return RET_ERROR;
}
step = data_shapes[0][0];
data_size = 8 * step;
new_shape = std::vector<int>({1, data_size});

} else {
if (data_shapes[0].size() != 2) {
MS_LOG(ERROR) << "weight data shape error";
return RET_ERROR;
}
new_shape = std::vector<int>({1, data_shapes[0][0] * 4, data_shapes[0][1]});
step = data_shapes[0][0] * data_shapes[0][1];
data_size = 4 * step;
}

tensor_data = new (std::nothrow) float[data_size];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new data failed";
return RET_ERROR;
}
for (int i = 0; i < data_size; ++i) {  // biases are copied into the first 4 * hidden_size elements; the rest stays 0
tensor_data[i] = 0.0f;
}

for (size_t i = 0; i < data_ptrs.size(); ++i) {
auto source_len = std::accumulate(data_shapes[i].begin(), data_shapes[i].end(), 1, std::multiplies<int>());
auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s error";
delete[] tensor_data;
return RET_ERROR;
}
}
new_default->set_tensor_shape(new_shape);
new_default->set_tensor_type(kNumberTypeFloat32);
new_default->set_format(schema::Format_NHWC);
new_default->SetTensorData(tensor_data, data_size * sizeof(float));
new_param->set_default_param(new_default);

std::vector<int64_t> shape_vector(new_shape.begin(), new_shape.end());
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "abstract_tensor is nullptr";
return RET_ERROR;
}
new_param->set_abstract(abstract_tensor);
return RET_OK;
}

CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const EquivPtr &body_equiv, const std::string &base_name,
const float smooth) const {
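// Builds the fused Lstm cnode from the while-node inputs matched by equiv: the raw input tensor
// (taken from the TensorListFromTensor feeding the loop), the concatenated input weights, the
// concatenated recurrent weights, the concatenated bias and the initial hidden/cell states; the
// matched smooth coefficient is stored in the LstmT attribute.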
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
MS_ASSERT(body_equiv != nullptr);
/*
* input vars for while node
* 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
* 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
*/
auto lstm_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>();
attr->bidirection = false;
attr->smooth = smooth;
lstm_primitive->value.type = schema::PrimitiveType_Lstm;
lstm_primitive->value.value = attr.release();
auto lstm_cvalue = lite::PrimitiveC::Create(lstm_primitive.release());
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(lstm_cvalue));

auto &vars = while_input_vars_;

auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]);
MS_ASSERT(limit1);
auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]);
MS_ASSERT(limit2);

auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
MS_ASSERT(i2i_weight);
auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
MS_ASSERT(i2f_weight);
auto i2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[11]]);
MS_ASSERT(i2c_weight);
auto i2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[12]]);
MS_ASSERT(i2o_weight);

auto c2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[13]]);
MS_ASSERT(c2i_weight);
auto c2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[14]]);
MS_ASSERT(c2f_weight);
auto c2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[15]]);
MS_ASSERT(c2c_weight);
auto c2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[16]]);
MS_ASSERT(c2o_weight);

auto i_bias = utils::cast<AnfNodePtr>((*equiv)[vars[17]]);
MS_ASSERT(i_bias);
auto f_bias = utils::cast<AnfNodePtr>((*equiv)[vars[18]]);
MS_ASSERT(f_bias);
auto c_bias = utils::cast<AnfNodePtr>((*equiv)[vars[19]]);
MS_ASSERT(c_bias);
auto o_bias = utils::cast<AnfNodePtr>((*equiv)[vars[20]]);
MS_ASSERT(o_bias);

auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]);
MS_ASSERT(input);
auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]);
MS_ASSERT(cell);
auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]);
MS_ASSERT(hidden);

std::vector<AnfNodePtr> i_weights{i2i_weight, i2o_weight, i2f_weight, i2c_weight};
auto i_weight = func_graph->add_parameter();
auto status = GetConcatedParam(i_weights, i_weight, false);
if (status != RET_OK) {
return nullptr;
}
i_weight->set_name(base_name + "_weight_i");

std::vector<AnfNodePtr> c_weights{c2i_weight, c2o_weight, c2f_weight, c2c_weight};
auto c_weight = func_graph->add_parameter();
status = GetConcatedParam(c_weights, c_weight, false);
if (status != RET_OK) {
return nullptr;
}
c_weight->set_name(base_name + "_weight_c");

std::vector<AnfNodePtr> biases{i_bias, o_bias, f_bias, c_bias};
auto bias = func_graph->add_parameter();
status = GetConcatedParam(biases, bias, true);
if (status != RET_OK) {
return nullptr;
}
bias->set_name(base_name + "_bias");

if (!utils::isa<CNodePtr>(input) || GetCNodeType(input) != schema::PrimitiveType_TensorListFromTensor) {
MS_LOG(DEBUG) << "input is not tensorlistfromtensor op";
return nullptr;
}
auto tensor_list_cnode = utils::cast<CNodePtr>(input);
auto input_tensor_node = tensor_list_cnode->input(1);

std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias, hidden, cell};
auto new_node = func_graph->NewCNode(new_node_inputs);
new_node->set_fullname_with_scope(base_name);
return new_node;
}

CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
const int item_index) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
auto tuple_get_item_prim_ptr = lite::GetTupleGetItemPrim();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
return nullptr;
}
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(item_index));
if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
MS_LOG(ERROR) << "NewValueNode is nullptr";
return nullptr;
}
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, node, get_item_value};
CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
std::vector<int64_t> shape_vector;
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "create abstract_tensor failed";
return nullptr;
}
get_item_cnode->set_abstract(abstract_tensor);
get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
std::to_string(item_index));
return get_item_cnode;
}

STATUS TfliteLstmCellFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode,
const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) const {
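// Redirects the remaining TupleGetItem users of the while node to the new LSTM node: while
// output 4 is remapped to LSTM output 2 and while output 5 to LSTM output 1, and a Squeeze on
// axis 0 is appended, presumably because the LSTM state outputs carry an extra leading axis.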
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(while_cnode != nullptr);
auto manager = func_graph->manager();
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr";
return RET_ERROR;
}
auto tr = manager->Transact();
auto while_node_users = manager->node_users()[while_cnode];
for (auto &node_user : while_node_users) {
if (node_user.first == output_get_item) {
continue;
}
if (!utils::isa<CNodePtr>(node_user.first)) {
return RET_ERROR;
}
auto get_item = utils::cast<CNodePtr>(node_user.first);
if (GetCNodeType(get_item) != schema::PrimitiveType_TupleGetItem) {
return RET_ERROR;
}
auto new_inputs = get_item->inputs();
if (new_inputs.size() != 3) {
return RET_ERROR;
}
new_inputs[1] = lstm_cnode;
auto index_vnode = get_item->input(2);
if (!utils::isa<ValueNode>(index_vnode)) {
MS_LOG(ERROR) << "TupleGetItem's input 2 is not a value node";
return RET_ERROR;
}
auto value_node = utils::cast<ValueNodePtr>(index_vnode);
if (value_node == nullptr) {
MS_LOG(ERROR) << "cast to ValueNode failed";
return RET_ERROR;
}
auto origin_index = GetValue<int>(value_node->value());
int new_index = origin_index == 4 ? 2 : 1;
auto new_index_vnode = NewValueNode(MakeValue<int>(new_index));
new_inputs[2] = new_index_vnode;
get_item->set_inputs(new_inputs);
get_item->set_fullname_with_scope(lstm_cnode->fullname_with_scope() + "_getitem_" + std::to_string(new_index));
if (get_item->abstract() == nullptr) {
MS_LOG(ERROR) << "get_item's abstract is nullptr";
return RET_ERROR;
}

std::vector<int> squeeze_axis{0};
auto squeeze_node = CreateSqueezeNode(func_graph, get_item, squeeze_axis);
if (squeeze_node == nullptr) {
return RET_ERROR;
}

auto get_item_users = manager->node_users()[get_item];
for (auto &get_item_user : get_item_users) {
tr.SetEdge(get_item_user.first, get_item_user.second, squeeze_node);
}
}
tr.Commit();
return RET_OK;
}

STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) {
MS_ASSERT(cnode != nullptr);
AbstractBasePtrList abstract_list;
for (int i = 0; i < output_num; ++i) {
std::vector<int64_t> shape_vector;
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "create abstract_tensor failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
}
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
if (abstract_tuple == nullptr) {
MS_LOG(ERROR) << "create abstract_tuple failed";
return RET_ERROR;
}
cnode->set_abstract(abstract_tuple);
return RET_OK;
}

CNodePtr TfliteLstmCellFusion::CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node,
const std::vector<int> &axis) const {
MS_ASSERT(func_graph != nullptr);
std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new SqueezeT failed";
return nullptr;
}
attr->axis = axis;
auto new_primitive_t = std::make_unique<schema::PrimitiveT>();
if (new_primitive_t == nullptr) {
MS_LOG(ERROR) << "primitive_t is nullptr";
return nullptr;
}
new_primitive_t->value.type = schema::PrimitiveType_Squeeze;
new_primitive_t->value.value = attr.release();
auto new_primitive_c = std::shared_ptr<lite::PrimitiveC>(lite::PrimitiveC::Create(new_primitive_t.release()));
if (new_primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return nullptr;
}
ValueNodePtr value_node = NewValueNode(new_primitive_c);
auto squeeze_cnode = func_graph->NewCNode({value_node, input_node});
squeeze_cnode->set_abstract(input_node->abstract()->Clone());
squeeze_cnode->set_fullname_with_scope("squeeze_" + input_node->fullname_with_scope());
return squeeze_cnode;
}

const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
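// Overall flow: starting from the matched TensorListStack node, walk back to the while cnode,
// validate its referenced outputs and its cond/body subgraphs, extract the smooth coefficient,
// create the fused LSTM node plus a get_item(0)/Squeeze pair for the stacked output, rewire the
// other get_items, drop the cond/body graph references, and return the squeeze node as the
// replacement for the original TensorListStack output.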
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_LOG(DEBUG) << "lstm fusion pass";
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}

if (!utils::isa<CNodePtr>(node)) {
return nullptr;
}
auto tensor_list_stack_cnode = utils::cast<CNodePtr>(node);
auto tuple_get_item_node = tensor_list_stack_cnode->input(1);
if (!utils::isa<CNodePtr>(tuple_get_item_node)) {
return nullptr;
}
auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item_node);
auto while_node = tuple_get_item_cnode->input(1);
if (!utils::isa<CNodePtr>(while_node)) {
return nullptr;
}
auto while_cnode = utils::cast<CNodePtr>(while_node);

if (CheckIfCNodeIsNull(while_cnode) != RET_OK || CheckInputSize(while_cnode, while_inputs_num_) != RET_OK) {
return nullptr;
}
if (!CheckReferencedOutputs(func_graph, while_cnode)) {
return nullptr;
}
PrimitiveVarMapPtr primitive_vars_cond = std::make_shared<PrimitiveVarMap>();
auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond);
auto cond_equiv = CheckSubGraph(func_graph, cond_graph_pattern, primitive_vars_cond, while_cnode->input(1),
cond_cnodes_num_, cond_nodes_num_);
if (cond_equiv == nullptr || cond_equiv->empty()) {
return nullptr;
}
PrimitiveVarMapPtr primitive_vars_body = std::make_shared<PrimitiveVarMap>();
auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body);
auto body_equiv = CheckSubGraph(func_graph, body_graph_pattern, primitive_vars_body, while_cnode->input(2),
body_cnodes_num_, body_nodes_num_);
if (body_equiv == nullptr || body_equiv->empty()) {
return nullptr;
}
float smooth = 0.0f;
if (!CheckBodyGraph(func_graph, body_equiv, while_cnode, &smooth)) {
return nullptr;
}
const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope();
auto lstm_node = CreateLSTMNode(func_graph, equiv, body_equiv, lstm_name, smooth);
if (lstm_node == nullptr) {
return nullptr;
}
auto status = SetAbstractTuple(lstm_node, kLSTMOutputNum);
if (status != RET_OK) {
return nullptr;
}

auto get_item_node = CreateOutputGetItem(func_graph, lstm_node, 0);
if (get_item_node == nullptr) {
MS_LOG(DEBUG) << "create lstm output get_item node failed";
return nullptr;
}

status = AdjustOtherGetItems(func_graph, while_cnode, lstm_node, tuple_get_item_cnode);
if (status != RET_OK) {
return nullptr;
}

std::vector<int> squeeze_axis{1};  // our lstm output 0 has an extra axis that the tflite output does not have, so it must be squeezed
auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis);
if (squeeze_node == nullptr) {
return nullptr;
}

auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1);
func_graph->DropFuncGraphCNodeIndex(cond_cnode_index_pair);
auto body_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 2);
func_graph->DropFuncGraphCNodeIndex(body_cnode_index_pair);
MS_LOG(INFO) << "lstm node:" << lstm_node->fullname_with_scope() << " fusion success";
return squeeze_node;
}
} // namespace opt
} // namespace mindspore

+ 82
- 0
mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h View File

@@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_

#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "include/errorcode.h"

namespace mindspore {
namespace opt {
class TfliteLstmCellFusion : public PatternProcessPass {
public:
explicit TfliteLstmCellFusion(const std::string &name = "tflite_lstm_cell_fusion", bool multigraph = true,
int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0,
int body_nodes_num = 0, int body_cnodes_num = 0);
~TfliteLstmCellFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

public:
static EquivPtr MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
const AnfNodePtr &pattern);
static EquivPtr CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern,
const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph,
const size_t cnode_num, const size_t all_node_num);
static lite::STATUS SetAbstractTuple(const CNodePtr &cnode, const int output_num);
static CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index);

protected:
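// Pattern variables capturing the smoothing coefficients and the while-node inputs matched by
// the patterns above; protected so that derived fusion passes overriding the virtual
// GetBodyGraphPattern/CreateLSTMNode hooks can reuse them.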
VarPtr cell_smooth_old_ = nullptr;
VarPtr cell_smooth_new_ = nullptr;
VarPtr hidden_smooth_old_ = nullptr;
VarPtr hidden_smooth_new_ = nullptr;
std::vector<VarPtr> while_input_vars_;

lite::STATUS GetFloatScalarFromParamValueLite(const AnfNodePtr &param_value, float *v) const;
CNodePtr CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node,
const std::vector<int> &axis) const;
lite::STATUS AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode,
const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) const;
AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
virtual CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv,
const std::string &base_name, const float smooth) const;

private:
bool CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const CNodePtr &while_cnode,
float *smooth) const;
bool CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) const;

lite::STATUS GetConcatedParam(const std::vector<AnfNodePtr> &params, const ParameterPtr &new_param,
bool is_bias) const;

private:
size_t while_input_var_num_ = 0;
size_t while_inputs_num_ = 0;
size_t cond_nodes_num_ = 0;
size_t cond_cnodes_num_ = 0;
size_t body_nodes_num_ = 0;
size_t body_cnodes_num_ = 0;
};

} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_
