diff --git a/mindspore/lite/nnacl/fp32/gru_fp32.c b/mindspore/lite/nnacl/fp32/gru_fp32.c new file mode 100644 index 0000000000..164e9cc195 --- /dev/null +++ b/mindspore/lite/nnacl/fp32/gru_fp32.c @@ -0,0 +1,134 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/gru_fp32.h" +#include <string.h> +#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/fp32/arithmetic_fp32.h" + +void InitGruGate(float *gate_buffer, const float *bias, const GruParameter *gru_parm) { + int gate_offest = 0; + for (int l = 0; l < 3; l++) { + int batch_offest = gate_offest; + int bias_offest = l * gru_parm->hidden_size_; + for (int b = 0; b < gru_parm->batch_; b++) { + memcpy(gate_buffer + batch_offest, bias + bias_offest, gru_parm->hidden_size_ * sizeof(float)); + batch_offest += gru_parm->hidden_size_; + } + gate_offest += gru_parm->batch_ * gru_parm->hidden_size_; + } +} + +void GruStepUnit(float *output, const float *input, const float *input_reset_weight, const float *input_update_weight, + const float *input_hidden_weight, const float *state_reset_weight, const float *state_update_weight, + const float *state_hidden_weight, const float *bias, float *hidden_state, float *gate_buffer, + const GruParameter *gru_parm) { + InitGruGate(gate_buffer, bias, gru_parm); + + float *update_gate = gate_buffer; + float *reset_gate = gate_buffer + gru_parm->batch_ * 
gru_parm->hidden_size_; + float *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2; + + // input * weight + MatMulAcc(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_); + MatMulAcc(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_); + MatMulAcc(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_); + + // state * weight + MatMulAcc(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, + gru_parm->hidden_size_); + MatMulAcc(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_, + gru_parm->hidden_size_); + + // update reset_gate + Sigmoid(reset_gate, gru_parm->batch_ * gru_parm->hidden_size_, reset_gate); + + // update update_gate + Sigmoid(update_gate, gru_parm->batch_ * gru_parm->hidden_size_, update_gate); + + ElementMul(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_); + MatMulAcc(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, + gru_parm->hidden_size_); + + Tanh(hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_, hidden_buffer); + + ElementMul(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_); + + ArithmeticParameter parameter; + parameter.in_elements_num0_ = 1; + parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_; + const float one = 1.0f; + ElementOptSub(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter); + + ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_); + + memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float)); +} + +void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias, + float *hidden_state, float 
*gate_buffer, int check_seq_len, const GruParameter *gru_parm) { + // forward + const float *input_update_weight = weight_g; + const float *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_; + const float *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2; + + const float *state_update_weight = weight_r; + const float *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_; + const float *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2; + + for (int t = 0; t < check_seq_len; t++) { + const float *input_ptr = input + t * gru_parm->input_step_; + float *output_ptr = output + t * gru_parm->output_step_; + GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, state_reset_weight, + state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer, gru_parm); + } + // zero out extra fw outputs + for (int t = check_seq_len; t < gru_parm->seq_len_; t++) { + float *output_ptr = output + t * gru_parm->output_step_; + for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + + // backward + if (gru_parm->bidirectional_) { + input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3; + input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4; + input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5; + + state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3; + state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4; + state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5; + + float *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_; + const float *backward_bias = bias + 3 * gru_parm->hidden_size_; + float *backward_hidden_state = hidden_state + gru_parm->batch_ * 
gru_parm->hidden_size_; + for (int t = check_seq_len - 1; t >= 0; t--) { + const float *input_ptr = input + t * gru_parm->input_step_; + float *output_ptr = backward_output + t * gru_parm->output_step_; + GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, + state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, backward_hidden_state, + gate_buffer, gru_parm); + } + // zero out extra bw outputs + for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) { + float *output_ptr = backward_output + t * gru_parm->output_step_; + for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + } +} diff --git a/mindspore/lite/nnacl/fp32/gru_fp32.h b/mindspore/lite/nnacl/fp32/gru_fp32.h new file mode 100644 index 0000000000..e247783501 --- /dev/null +++ b/mindspore/lite/nnacl/fp32/gru_fp32.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_ +#define MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_ +#include "nnacl/op_base.h" + +typedef struct GruParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; // output_size + int seq_len_; + int batch_; + // other parameter + int input_step_; + int output_step_; + bool bidirectional_; +} GruParameter; + +#ifdef __cplusplus +extern "C" { +#endif +void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias, + float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_ diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.c b/mindspore/lite/nnacl/fp32/lstm_fp32.c index d4c5edf914..fd615201bd 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.c +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.c @@ -16,6 +16,7 @@ #include "nnacl/fp32/lstm_fp32.h" #include +#include #include "nnacl/fp32/activation_fp32.h" #include "nnacl/fp32/arithmetic_fp32.h" @@ -79,21 +80,63 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int } } +int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 4; index += C4NUM) { + float32x4_t vin0 = vld1q_f32(input0 + index); + float32x4_t vout = vld1q_f32(output + index); + vout = vmlaq_n_f32(vout, vin0, input1); + vst1q_f32(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] += input0[index] * input1; + } + return NNACL_OK; +} + void UpdataState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate, - int batch, int hidden_size) { + float *state_buffer, int batch, int hidden_size, const float smooth) { + if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // smooth 
* old_cell_state + memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float)); + ArithmeticParameter parameter; + parameter.in_elements_num0_ = batch * hidden_size; + parameter.in_elements_num1_ = 1; + ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter); + } + ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size); ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size); + + if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // (1 - smooth) * new_cell_state + ElementOptMulAcc(cell_state, 1 - smooth, state_buffer, batch * hidden_size); + } } -void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, int batch, int hidden_size) { +void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, float *state_buffer_in, + int batch, int hidden_size, const float smooth) { + float *state_buffer = state_buffer_in + batch * hidden_size; + if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { + memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float)); + ArithmeticParameter parameter; + parameter.in_elements_num0_ = batch * hidden_size; + parameter.in_elements_num1_ = 1; + ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter); + } + Tanh(cell_state, batch * hidden_size, hidden_state); ElementMul(hidden_state, output_gate, hidden_state, batch * hidden_size); + + if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { + ElementOptMulAcc(hidden_state, 1 - smooth, state_buffer, batch * hidden_size); + } } void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight, const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight, const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight, - const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, + 
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, const LstmParameter *lstm_parm) { InitGate(gate_buffer, bias, lstm_parm); @@ -129,17 +172,26 @@ void LstmStepUnit(float *output, const float *input, const float *input_input_we // update cell_gate Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate); // update cell state - UpdataState(cell_state, forget_gate, input_gate, cell_gate, lstm_parm->batch_, lstm_parm->hidden_size_); + UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->smooth_); // update output_gate Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate); // update output - UpdataOutput(cell_state, output_gate, hidden_state, lstm_parm->batch_, lstm_parm->hidden_size_); + UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->smooth_); memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); + + if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) { + memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); + memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_, + lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); + } } void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, - float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm) { + float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, + const LstmParameter *lstm_parm) { // forward const float *input_input_weight = weight_i; const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; @@ -156,7 +208,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const 
float float *output_ptr = output + t * lstm_parm->output_step_; LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state, - cell_state, gate_buffer, lstm_parm); + cell_state, gate_buffer, state_buffer, lstm_parm); } // backward @@ -180,7 +232,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float float *output_ptr = backward_output + t * lstm_parm->output_step_; LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, - backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, lstm_parm); + backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm); } } } diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.h b/mindspore/lite/nnacl/fp32/lstm_fp32.h index 68c7c94178..265e56058e 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.h +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.h @@ -31,13 +31,24 @@ typedef struct LstmParameter { int input_step_; int output_step_; bool bidirectional_; + // smooth factor for hidden/cell state calculation: + // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth) + // output_cell = old_cell * smooth + new_cell * (1 - smooth) + float smooth_; } LstmParameter; #ifdef __cplusplus extern "C" { #endif +void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size); + +void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size); + +int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); + void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, - float *hidden_state, float *cell_state, float 
*gate_buffer, const LstmParameter *lstm_parm); + float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, + const LstmParameter *lstm_parm); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 20c666e814..2bda7fee10 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -262,6 +262,7 @@ union PrimitiveType { Merge, Mod, GeLU, + Gru, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index c0b008be42..3e4116c02d 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1005,6 +1005,11 @@ table OneHot { table Lstm{ bidirection: bool = false; + smooth: float = 0.0; +} + +table Gru{ + bidirection: bool = false; } table PriorBox { diff --git a/mindspore/lite/src/ops/gru.cc b/mindspore/lite/src/ops/gru.cc new file mode 100644 index 0000000000..40ae70335d --- /dev/null +++ b/mindspore/lite/src/ops/gru.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "src/ops/gru.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +bool Gru::GetBidirection() const { return this->primitive_->value.AsGru()->bidirection; } + +void Gru::SetBidirection(bool bidirection) { this->primitive_->value.AsGru()->bidirection = bidirection; } + +#else + +bool Gru::GetBidirection() const { return this->primitive_->value_as_Gru()->bidirection(); } +int Gru::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Gru(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Gru return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateGru(*fbb, attr->bidirection()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gru, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *GruCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Gru>(primitive); } +Registry GruRegistry(schema::PrimitiveType_Gru, GruCreator); +#endif + +const int kGruInputNum = 5; +const int kGruInputWithSeqLenNum = 6; +const int kGruOutputNum = 2; +int Gru::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { + MS_ASSERT(this->primitive_ != nullptr); + if ((inputs_.size() != kGruInputNum && inputs_.size() != kGruInputWithSeqLenNum) || + outputs_.size() != kGruOutputNum) { + MS_LOG(ERROR) << "OpGru inputs or outputs size error."; + return RET_INPUT_TENSOR_ERROR; + } + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight_gate = inputs_.at(1); + MS_ASSERT(weight_gate != nullptr); + auto weight_recurrence = inputs_.at(2); + MS_ASSERT(weight_recurrence != nullptr); + auto bias = inputs_.at(3); + MS_ASSERT(bias != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + for (int i = 0; i < kGruOutputNum; i++) { + 
outputs_.at(i)->set_data_type(input->data_type()); + outputs_.at(i)->set_format(input->format()); + } + if (!infer_flag()) { + return RET_INFER_INVALID; + } + + auto in_shape = input->shape(); // seq_len, batch, input_size + auto w_gate_shape = weight_gate->shape(); // num_direction, hidden_size * 3, input_size + auto w_recu_shape = weight_recurrence->shape(); // num_direction, hidden_size * 3, hidden_size + auto bias_shape = bias->shape(); // num_direction, hidden_size * 6 + if (in_shape.size() != 3 || w_gate_shape.size() != 3 || w_recu_shape.size() != 3) { + MS_LOG(ERROR) << "OpGru input dims should be 3."; + return RET_ERROR; + } + if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) { + MS_LOG(ERROR) << "OpGru w_gate, w_recu and bias hidden size not match."; + return RET_ERROR; + } + if (inputs_.size() == kGruInputWithSeqLenNum) { + auto seq_len_shape = inputs_.at(5)->shape(); + if (seq_len_shape[0] > 1) { + MS_LOG(WARNING) << "OpGru with batch_size > 1 only support all same sequence_len now."; + return RET_ERROR; + } + if (seq_len_shape.size() != 1 && seq_len_shape[0] != in_shape[1]) { + MS_LOG(ERROR) << "OpGru sequence_len shape[0] and batch_size not match."; + return RET_ERROR; + } + } + + int hidden_size = w_gate_shape[1] / 3; + // set output + std::vector out_shape(in_shape); + out_shape[2] = hidden_size; + if (GetBidirection()) { + out_shape.insert(out_shape.begin() + 1, 2); + } else { + out_shape.insert(out_shape.begin() + 1, 1); + } + output->set_shape(out_shape); + // set hidden state + std::vector state_shape(in_shape); + state_shape[0] = GetBidirection() ? 
2 : 1; + state_shape[2] = hidden_size; + outputs_[1]->set_shape(state_shape); + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/gru.h b/mindspore/lite/src/ops/gru.h new file mode 100644 index 0000000000..84ca28fb9b --- /dev/null +++ b/mindspore/lite/src/ops/gru.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_OPS_GRU_H_ +#define MINDSPORE_LITE_SRC_OPS_GRU_H_ +#include +#include +#include + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +/* + * gru with linear_before_reset = 0 + */ +class Gru : public PrimitiveC { + public: + Gru() = default; + ~Gru() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Gru, PrimitiveC); + explicit Gru(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + void SetBidirection(bool bidirection); + +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + bool GetBidirection() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_OPS_GRU_H_ diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc index 8963020915..7d1a398784 100644 --- a/mindspore/lite/src/ops/lstm.cc +++ b/mindspore/lite/src/ops/lstm.cc @@ -25,11 +25,16 @@ namespace lite { 
#ifdef PRIMITIVE_WRITEABLE bool Lstm::GetBidirection() const { return this->primitive_->value.AsLstm()->bidirection; } +float Lstm::GetSmooth() const { return this->primitive_->value.AsLstm()->smooth; } + void Lstm::SetBidirection(bool bidirection) { this->primitive_->value.AsLstm()->bidirection = bidirection; } +void Lstm::SetSmooth(float smooth) { this->primitive_->value.AsLstm()->smooth = smooth; } + #else bool Lstm::GetBidirection() const { return this->primitive_->value_as_Lstm()->bidirection(); } +float Lstm::GetSmooth() const { return this->primitive_->value_as_Lstm()->smooth(); } int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); @@ -38,7 +43,7 @@ int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F MS_LOG(ERROR) << "value_as_Lstm return nullptr"; return RET_ERROR; } - auto val_offset = schema::CreateLstm(*fbb, attr->bidirection()); + auto val_offset = schema::CreateLstm(*fbb, attr->bidirection(), attr->smooth()); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Lstm, val_offset.o); fbb->Finish(prim_offset); return RET_OK; diff --git a/mindspore/lite/src/ops/lstm.h b/mindspore/lite/src/ops/lstm.h index 7944b97370..fd58a99a46 100644 --- a/mindspore/lite/src/ops/lstm.h +++ b/mindspore/lite/src/ops/lstm.h @@ -33,12 +33,14 @@ class Lstm : public PrimitiveC { MS_DECLARE_PARENT(Lstm, PrimitiveC); explicit Lstm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetBidirection(bool bidirection); + void SetSmooth(float smooth); #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetBidirection() const; + float GetSmooth() const; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/gru_populate.cc 
b/mindspore/lite/src/ops/populate/gru_populate.cc new file mode 100644 index 0000000000..1e57855d30 --- /dev/null +++ b/mindspore/lite/src/ops/populate/gru_populate.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/gru.h" +#include "src/ops/primitive_c.h" +#include "src/ops/populate/populate_register.h" +#include "nnacl/fp32/gru_fp32.h" + +namespace mindspore { +namespace lite { +OpParameter *PopulateGruParameter(const mindspore::lite::PrimitiveC *primitive) { + GruParameter *gru_param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter))); + if (gru_param == nullptr) { + MS_LOG(ERROR) << "malloc GruParameter failed."; + return nullptr; + } + memset(gru_param, 0, sizeof(GruParameter)); + gru_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast<mindspore::lite::Gru *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); + if (param == nullptr) { + free(gru_param); + MS_LOG(ERROR) << "get Gru param nullptr."; + return nullptr; + } + gru_param->bidirectional_ = param->GetBidirection(); + return reinterpret_cast<OpParameter *>(gru_param); +} +Registry GruParameterRegistry(schema::PrimitiveType_Gru, PopulateGruParameter); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/lstm_populate.cc b/mindspore/lite/src/ops/populate/lstm_populate.cc index 7939498b10..95642daec0 --- a/mindspore/lite/src/ops/populate/lstm_populate.cc +++ 
b/mindspore/lite/src/ops/populate/lstm_populate.cc @@ -36,6 +36,7 @@ OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } lstm_param->bidirectional_ = param->GetBidirection(); + lstm_param->smooth_ = param->GetSmooth(); return reinterpret_cast(lstm_param); } Registry LstmParameterRegistry(schema::PrimitiveType_Lstm, PopulateLstmParameter); diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index a8f316ab03..a486717c0d 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -161,6 +161,7 @@ #include "src/ops/switch.h" #include "src/ops/partial.h" #include "src/ops/gelu.h" +#include "src/ops/gru.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -995,6 +996,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) AssertOP(primitive); case schema::PrimitiveType_GeLU: return new (std::nothrow) GeLU(primitive); + case schema::PrimitiveType_Gru: + return new (std::nothrow) Gru(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc new file mode 100644 index 0000000000..e7a1911bc9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc @@ -0,0 +1,165 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/gru_fp32.h" +#include <vector> +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Gru; + +namespace mindspore::kernel { +void GruCPUKernel::FreeTmpBuffer() { + if (gate_buffer_ != nullptr) { + free(gate_buffer_); + gate_buffer_ = nullptr; + } + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; + } + weight_g_ptr_ = nullptr; + weight_r_ptr_ = nullptr; +} + +int GruCPUKernel::InitParam() { + auto input = in_tensors_.front(); + MS_ASSERT(input != nullptr); + std::vector<int> in_shape = input->shape(); + gru_parm_->seq_len_ = in_shape.at(0); + gru_parm_->batch_ = in_shape.at(1); + gru_parm_->input_size_ = in_shape.at(2); + + auto weight_g = in_tensors_.at(1); + MS_ASSERT(weight_g != nullptr); + std::vector<int> w_shape = weight_g->shape(); + gru_parm_->hidden_size_ = w_shape.at(1) / 3; + + gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_; + gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 
2 * gru_parm_->batch_ * gru_parm_->hidden_size_ + : gru_parm_->batch_ * gru_parm_->hidden_size_; + return RET_OK; +} + +int GruCPUKernel::InitBuffer() { + gate_buffer_ = reinterpret_cast<float *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float))); + if (gate_buffer_ == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error."; + return RET_ERROR; + } + return RET_OK; +} + +int GruCPUKernel::InitWeightBias() { + auto weight_gate = in_tensors_.at(1); + MS_ASSERT(weight_gate != nullptr); + weight_g_ptr_ = reinterpret_cast<float *>(weight_gate->data_c()); + + auto weight_recu = in_tensors_.at(2); + MS_ASSERT(weight_recu != nullptr); + weight_r_ptr_ = reinterpret_cast<float *>(weight_recu->data_c()); + + int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_; + bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float))); + if (bias_ptr_ == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error."; + return RET_ERROR; + } + + auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c()); + const int state_bias_offset = 3 * gru_parm_->hidden_size_; + for (int i = 0; i < state_bias_offset; i++) { + bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset]; + } + if (gru_parm_->bidirectional_) { + bias_data += 3 * gru_parm_->hidden_size_ * 2; + auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_; + for (int i = 0; i < state_bias_offset; i++) { + backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset]; + } + } + return RET_OK; +} + +int GruCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int GruCPUKernel::ReSize() { + FreeTmpBuffer(); + auto ret = InitParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GruCPUKernel InitParam error."; + return RET_ERROR; + } + + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error."; + FreeTmpBuffer(); + return RET_ERROR; + } + + ret = InitBuffer(); + if 
(ret != RET_OK) { + MS_LOG(ERROR) << "GruCPUKernel InitBuffer error."; + FreeTmpBuffer(); + return RET_ERROR; + } + return RET_OK; +} + +int GruCPUKernel::Run() { + auto input = in_tensors_.at(kInputIndex); + MS_ASSERT(input != nullptr); + auto hidden_state = in_tensors_.at(4); + MS_ASSERT(hidden_state != nullptr); + auto output = out_tensors_.at(0); + MS_ASSERT(output != nullptr); + auto input_ptr = reinterpret_cast<float *>(input->data_c()); + MS_ASSERT(input_ptr); + auto output_ptr = reinterpret_cast<float *>(output->MutableData()); + MS_ASSERT(output_ptr); + auto output_hidden_state = out_tensors_[1]; + memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float)); + int check_seq_len = gru_parm_->seq_len_; + if (in_tensors_.size() == 6) { + auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c()); + if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) { + MS_LOG(ERROR) << "different batch seq_len is currently not supported"; + return RET_ERROR; + } + check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0])); + } + + MS_ASSERT(weight_g_ptr_); + MS_ASSERT(weight_r_ptr_); + MS_ASSERT(bias_ptr_); + MS_ASSERT(gate_buffer_); + Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_, + reinterpret_cast<float *>(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_); + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gru, LiteKernelCreator<GruCPUKernel>) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h new file mode 100644 index 0000000000..720323d520 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_ +#include +#include "src/lite_kernel.h" +#include "nnacl/fp32/gru_fp32.h" + +namespace mindspore::kernel { +class GruCPUKernel : public LiteKernel { + public: + GruCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + gru_parm_ = reinterpret_cast(op_parameter_); + } + + ~GruCPUKernel() override { FreeTmpBuffer(); } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + void FreeTmpBuffer(); + int InitParam(); + int InitBuffer(); + int InitWeightBias(); + + float *gate_buffer_ = nullptr; + const float *weight_g_ptr_ = nullptr; + const float *weight_r_ptr_ = nullptr; + float *bias_ptr_ = nullptr; + GruParameter *gru_parm_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc index 607f46c48b..2ab97225c4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/fp32/lstm_fp32.h" +#include #include #include "schema/model_generated.h" #include "src/kernel_registry.h" @@ -32,6 +33,10 @@ void 
LstmCPUKernel::FreeTmpBuffer() { free(gate_buffer_); gate_buffer_ = nullptr; } + if (state_buffer_ != nullptr) { + free(state_buffer_); + state_buffer_ = nullptr; + } if (weight_i_ptr_ != nullptr) { free(weight_i_ptr_); weight_i_ptr_ = nullptr; @@ -71,6 +76,14 @@ int LstmCPUKernel::InitBuffer() { MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; return RET_ERROR; } + if (!(lstm_parm_->smooth_ >= -FLT_EPSILON && lstm_parm_->smooth_ <= FLT_EPSILON)) { + int buffer_size = 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float); + state_buffer_ = reinterpret_cast(malloc(buffer_size)); + if (state_buffer_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error."; + return RET_ERROR; + } + } return RET_OK; } @@ -173,7 +186,7 @@ int LstmCPUKernel::Run() { MS_ASSERT(gate_buffer_); Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, reinterpret_cast(output_hidden_state->MutableData()), - reinterpret_cast(output_cell_state->MutableData()), gate_buffer_, lstm_parm_); + reinterpret_cast(output_cell_state->MutableData()), gate_buffer_, state_buffer_, lstm_parm_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h index 82ed1d6e70..a2ced62b73 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h @@ -44,6 +44,7 @@ class LstmCPUKernel : public LiteKernel { int InitWeightBias(); float *gate_buffer_ = nullptr; + float *state_buffer_ = nullptr; float *weight_i_ptr_ = nullptr; float *weight_h_ptr_ = nullptr; float *bias_ptr_ = nullptr; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index cec95bf66a..dfb6054a4f 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -187,6 +187,9 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc 
${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc diff --git a/mindspore/lite/test/models_tf.cfg b/mindspore/lite/test/models_tf.cfg index e69de29bb2..7712e1d401 100644 --- a/mindspore/lite/test/models_tf.cfg +++ b/mindspore/lite/test/models_tf.cfg @@ -0,0 +1 @@ +decoder_step_201217.pb 5 diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh index a242e14783..cd42456882 100644 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -925,6 +925,41 @@ function Run_arm64() { fi done < ${models_compatibility_config} + # Run tf converted models: + while read line; do + model_name=${line} + if [[ $model_name == \#* ]]; then + continue + fi + model_name=`echo ${tf_line_info}|awk -F ' ' '{print $1}'` + input_num=`echo ${tf_line_info}|awk -F ' ' '{print $2}'` + input_files='' + for i in $(seq 1 $input_num) + do + input_files=$input_files'/data/local/tmp/input_output/input/'$model_name'.ms_'$i'.bin,' + done + echo ${model_name} >> "${run_arm64_log_file}" + echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> "${run_arm64_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' 
--benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> adb_run_cmd.txt + adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}" + if [ $? = 0 ]; then + run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} + else + run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 + fi + # run benchmark test without clib data + echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> "${run_arm64_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt + adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}" + if [ $? = 0 ]; then + run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} + else + run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 + fi + done < ${models_tf_config} + # Run tflite converted models: while read line; do model_name=${line} diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 6992601f4f..6117927085 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -46,6 +46,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/batchmatmul_fusion.cc ../optimizer/fusion/sigmoid_mul_fusion.cc ../optimizer/fusion/conv_conv_fusion.cc + ../optimizer/fusion/tflite_lstm_cell_fusion.cc + ../optimizer/fusion/tf_lstm_cell_fusion.cc + ../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc ../optimizer/graph/weight_format_transform_pass.cc ../optimizer/graph/weight_format_hardcode_pass.cc 
../optimizer/graph/clip_convert_activation_pass.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 094c54b59d..a6ac5687ff 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -29,6 +29,9 @@ #include "tools/optimizer/fusion/batchmatmul_fusion.h" #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" #include "tools/optimizer/fusion/conv_conv_fusion.h" +#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" +#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h" +#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" #include "tools/optimizer/graph/mindir_adjust_pass.h" #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" #include "tools/optimizer/graph/identity_remove_pass.h" @@ -114,6 +117,9 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); + fusion_pm->AddPass(std::make_shared()); + fusion_pm->AddPass(std::make_shared()); + fusion_pm->AddPass(std::make_shared()); } auto weight_format_hardcode_pass = std::make_shared(); weight_format_hardcode_pass->SetFmkType(config->fmk); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index dad98203f9..7e13a46a29 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -572,7 +572,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) { type = TypeIdToType(kObjectTypeTensorType); } - anf_node->set_abstract(std::make_shared(type, shape_vector)); + auto abstract = std::make_shared(type, shape_vector); + if (abstract == nullptr) { + MS_LOG(ERROR) << "create 
AbstractTensor failed"; + return RET_ERROR; + } + anf_node->set_abstract(abstract); anf_node_map->insert(std::pair(op.name(), anf_node)); } else { AbstractBasePtrList abstractList; @@ -589,6 +594,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C std::vector inputs{tupleGetItemPrim, anf_node, getItemValue}; CNodePtr getItemCNode = anf_graph->NewCNode(inputs); std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); + auto abstract = std::make_shared(kFloat32, shape_vector); + if (abstract == nullptr) { + MS_LOG(ERROR) << "create AbstractTensor failed"; + return RET_ERROR; + } + getItemCNode->set_abstract(abstract); getItemCNode->set_fullname_with_scope(output_item_name); anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc index 4fda58a80a..e4f1ee1ea9 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc @@ -63,7 +63,11 @@ STATUS TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op, } *output_size = 1; - return AddOpInput(tf_op, 0, inputs); + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + return AddOpInput(tf_op, 1, inputs); } TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tf/tf_select_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_select_parser.cc new file mode 100644 index 0000000000..a4d3c8338c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_select_parser.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_select_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSelectParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SelectParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_Switch; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + return RET_OK; +} +TFNodeRegistrar g_tfSelectParser("Select", new TFSelectParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_select_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_select_parser.h new file mode 100644 index 
0000000000..79e1fa8da5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_select_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSelectParser : public TFNodeParser { + public: + TFSelectParser() = default; + ~TFSelectParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.cc b/mindspore/lite/tools/converter/parser/tf/tf_util.cc index e95b040724..1a605da670 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.cc @@ -122,7 +122,7 @@ std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) { std::sregex_token_iterator()); std::string ret = input_name; if (input_splits.size() == 3) { - if (input_splits[2] == "0") { + if (input_splits[2].compare("0") == 0) { ret = input_splits[0]; } else 
{ ret = input_splits[0] + ":" + input_splits[2]; // multi output node diff --git a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc new file mode 100644 index 0000000000..83e69b28d3 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc @@ -0,0 +1,679 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" +#include +#include +#include "src/ops/primitive_c.h" +#include "src/common/utils.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kWhileCommonInputsLength = 2; +constexpr size_t kWhileUniqInputsLength = 6; +constexpr size_t kCondNodesNum = 12; +constexpr size_t kCondCNodesNum = 4; +constexpr size_t kBodyNodesNum = 69; +constexpr size_t kBodyCNodesNum = 25; +const auto &p1 = std::placeholders::_1; + +bool IsParameterNode(const BaseRef &n) { return utils::isa(n); } + +bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) { + if (utils::isa(n) || utils::isa(n)) { + return opt::GetCNodeType(n) == type; + } + return false; +} +} // namespace + +BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, bool multigraph) + : PatternProcessPass(name, multigraph) { + /* + * vars for while input + * common: + * 0:const0 1:init_state + * fw_while_inputs: + * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias + * bw_while_inputs: + * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias + */ + for (size_t i = 0; i < kWhileCommonInputsLength; ++i) { + common_vars_.emplace_back(std::make_shared()); + } + for (size_t i = 0; i < kWhileUniqInputsLength; ++i) { + fw_vars_.emplace_back(std::make_shared()); + bw_vars_.emplace_back(std::make_shared()); + } + input_ = std::make_shared(); + input_length_ = std::make_shared(); + transpose_input_ = std::make_shared(); +} + +const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { + auto const1 = std::make_shared(IsParameterNode); + auto ele_shape = std::make_shared(IsParameterNode); + + // forward + auto fw_max1 = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_}); + auto fw_max2 = + 
VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, fw_max1}); + + auto fw_shape = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_}); + auto fw_stride = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), fw_shape}); + auto fw_min = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2}); + + auto fw_reserve = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape, + fw_stride}); + auto fw_from_tensor = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), + transpose_input_, ele_shape}); + auto is_fw_while = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_While)); + auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], common_vars_[0], fw_stride, common_vars_[0], + fw_reserve, common_vars_[1], fw_min, fw_from_tensor, input_length_}); + fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end()); + fw_while.emplace_back(common_vars_[1]); + auto fw_get_item = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)), + fw_while, std::make_shared()}); + auto fw_stack = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)), + fw_get_item, ele_shape}); + auto fw_out_trans = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), fw_stack}); + + // backward + auto bw_reverse_seq = VectorRef( + {std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), input_, input_length_}); + auto bw_max1 = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_}); + auto bw_max2 = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, bw_max1}); + auto bw_trans 
= + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_reverse_seq}); + auto bw_shape = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans}); + auto bw_stride = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), bw_shape}); + auto bw_min = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2}); + auto bw_reserve = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape, + bw_stride}); + auto bw_from_tensor = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), bw_trans, + ele_shape}); + auto is_bw_while = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_While)); + auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], common_vars_[0], bw_stride, common_vars_[0], + bw_reserve, common_vars_[1], bw_min, bw_from_tensor, input_length_}); + bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end()); + bw_while.emplace_back(common_vars_[1]); + auto bw_get_item = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)), + bw_while, std::make_shared()}); + auto bw_stack = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)), + bw_get_item, ele_shape}); + auto bw_out_trans = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_stack}); + auto bw_reverse1 = + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), bw_out_trans, + input_length_}); + + auto concat = VectorRef( + {std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Concat)), fw_out_trans, bw_reverse1}); + return concat; +} + +AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { + auto is_parameter1 
= std::make_shared(IsParameterNode); + auto is_parameter2 = std::make_shared(IsParameterNode); + auto is_parameter3 = std::make_shared(IsParameterNode); + auto is_parameter4 = std::make_shared(IsParameterNode); + auto is_less1 = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); + auto is_less2 = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); + auto is_logical_and = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_LogicalAnd)); + auto is_return = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); + VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2}); + VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4}); + VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref}); + VectorRef return_ref = VectorRef({is_return, logicaland_ref}); + VarPtr fg = std::make_shared("RootG"); + auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true); + return pattern; +} + +AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { + std::vector placeholders; + for (int i = 0; i < 13; ++i) { + placeholders.emplace_back(std::make_shared(IsParameterNode)); + } + VectorRef add = VectorRef({std::make_shared(), placeholders[2], std::make_shared(IsParameterNode)}); + VectorRef add1 = VectorRef({std::make_shared(), placeholders[0], std::make_shared(IsParameterNode)}); + + VectorRef get_item = VectorRef( + {std::make_shared("GetItem"), placeholders[6], placeholders[2], std::make_shared(IsParameterNode)}); + VectorRef concat_input_h = VectorRef({std::make_shared(), get_item, placeholders[4]}); + + VectorRef matmul1 = VectorRef({std::make_shared("Matmul"), concat_input_h, placeholders[8]}); + VectorRef biasadd1 = VectorRef({std::make_shared("BiasAdd"), matmul1, placeholders[9]}); + VectorRef sigmoid1 = VectorRef({std::make_shared("Sigmoid"), biasadd1}); + + VectorRef split = 
VectorRef({std::make_shared("Split"), sigmoid1}); + VectorRef get_item1 = VectorRef({std::make_shared("TupleGetItem"), split, std::make_shared()}); + VectorRef get_item2 = VectorRef({std::make_shared("TupleGetItem"), split, std::make_shared()}); + + VectorRef pre_reset = VectorRef({std::make_shared("Mul"), get_item1, placeholders[4]}); + VectorRef concat2 = VectorRef({std::make_shared("Concat"), get_item, pre_reset}); + VectorRef matmul2 = VectorRef({std::make_shared("Matmul"), concat2, placeholders[10]}); + VectorRef biasadd2 = VectorRef({std::make_shared("BiasAdd"), matmul2, placeholders[11]}); + VectorRef tanh = VectorRef({std::make_shared("Tanh"), biasadd2}); + + VectorRef update_hidden = VectorRef({std::make_shared("Mul"), get_item2, placeholders[4]}); + VectorRef minus_update = + VectorRef({std::make_shared("Sub"), std::make_shared(IsParameterNode), get_item2}); + VectorRef updated = VectorRef({std::make_shared("Mul"), minus_update, tanh}); + + VectorRef new_hidden = VectorRef({std::make_shared("Add"), update_hidden, updated}); + + VectorRef greater_equal = VectorRef({std::make_shared("GreaterEqual"), placeholders[2], placeholders[7]}); + + VectorRef select_output = VectorRef({std::make_shared("Switch"), greater_equal, placeholders[12], new_hidden}); + VectorRef output = VectorRef({std::make_shared("SetItem"), placeholders[3], placeholders[2], select_output}); + + VectorRef select_hidden = VectorRef({std::make_shared("Switch"), greater_equal, placeholders[4], new_hidden}); + + auto is_make_tuple = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple)); + std::vector outputs = {is_make_tuple, add1, placeholders[1], add, + output, select_hidden, placeholders[5], placeholders[6], + placeholders[7]}; + outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end()); + VectorRef make_tuple_node = VectorRef(outputs); + auto is_return = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); + VectorRef return_node = 
VectorRef({is_return, make_tuple_node}); + + VarPtr fg = std::make_shared("RootG"); + auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true); + return pattern; +} + +ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const { + MS_ASSERT(parameter_anf != nullptr); + if (!utils::isa(parameter_anf)) { + MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr"; + return nullptr; + } + auto parameter = utils::cast(parameter_anf); + if (!parameter->has_default()) { + MS_LOG(DEBUG) << "parameter not have default value"; + return nullptr; + } + auto param_value = std::dynamic_pointer_cast(parameter->default_param()); + return param_value; +} + +STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, + const AnfNodePtr &bw_cand_kernel_anf, int *input_size, + int *hidden_size) const { + MS_ASSERT(fw_cand_kernel != nullptr); + MS_ASSERT(bw_cand_kernel != nullptr); + MS_ASSERT(input_size != nullptr); + MS_ASSERT(hidden_size != nullptr); + auto fw_cand_kernel_value = GetDefaultParamValue(fw_cand_kernel_anf); + if (fw_cand_kernel_value == nullptr) { + return RET_ERROR; + } + auto fw_cand_kernel_shape = fw_cand_kernel_value->tensor_shape(); + if (fw_cand_kernel_shape.size() != 2) { + return RET_ERROR; + } + auto bw_cand_kernel_value = GetDefaultParamValue(bw_cand_kernel_anf); + if (bw_cand_kernel_value == nullptr) { + return RET_ERROR; + } + auto bw_cand_kernel_shape = bw_cand_kernel_value->tensor_shape(); + if (bw_cand_kernel_shape.size() != 2) { + return RET_ERROR; + } + if (fw_cand_kernel_shape != bw_cand_kernel_shape) { + return RET_ERROR; + } + if (fw_cand_kernel_shape[1] <= 0 || fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1] <= 0) { + MS_LOG(DEBUG) << "gru input size or hidden size illegal"; + return RET_ERROR; + } + *hidden_size = fw_cand_kernel_shape[1]; + *input_size = fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1]; + return RET_OK; +} + +ParameterPtr 
BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name, + const std::vector &shape, const TypeId type, + void **tensor_data) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(tensor_data != nullptr); + auto parameter = func_graph->add_parameter(); + parameter->set_name(name); + std::vector shape_vector(shape.begin(), shape.end()); + auto abstract_tensor = std::make_shared(TypeIdToType(type), shape_vector); + if (abstract_tensor == nullptr) { + return nullptr; + } + parameter->set_abstract(abstract_tensor); + + auto gate_weight_default = std::make_shared(); + if (gate_weight_default == nullptr) { + MS_LOG(ERROR) << "gate_weight_default is nullptr"; + return nullptr; + } + gate_weight_default->set_tensor_shape(shape); + gate_weight_default->set_tensor_type(type); + gate_weight_default->set_format(schema::Format_NHWC); + int data_len = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + int data_size = 0; + if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) { + data_size = data_len * sizeof(float); + *tensor_data = new (std::nothrow) float[data_len]; + } else if (type == kNumberTypeInt || type == kNumberTypeInt32) { + data_size = data_len * sizeof(int); + *tensor_data = new (std::nothrow) int[data_len]; + } else { + MS_LOG(DEBUG) << "unsupported data type"; + return nullptr; + } + if (*tensor_data == nullptr) { + MS_LOG(ERROR) << "new data failed"; + return nullptr; + } + + gate_weight_default->SetTensorData(*tensor_data, data_size); + parameter->set_default_param(gate_weight_default); + return parameter; +} + +void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, + const int r1, const int c0, const int c1, float *data, + bool t) const { + MS_ASSERT(mat != nullptr); + MS_ASSERT(data != nullptr); + MS_ASSERT(0 <= r0 && r0 < r1 && r1 <= R); + MS_ASSERT(0 <= c0 && c0 < c1 && c1 <= C); + const int RT = r1 - r0; + const int CT = c1 - c0; 
+ for (int i = r0; i < r1; ++i) { + for (int j = c0; j < c1; ++j) { + if (t) { + data[(j - c0) * RT + (i - r0)] = mat[i * C + j]; + } else { + data[(i - r0) * CT + (j - c0)] = mat[i * C + j]; + } + } + } +} + +STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, + const int input_size, const int hidden_size, + float *gate_tensor_data, float *recu_tensor_data) const { + MS_ASSERT(gate_weight != nullptr); + MS_ASSERT(cand_weight != nullptr); + MS_ASSERT(gate_tensor_data != nullptr); + MS_ASSERT(recu_tensor_data != nullptr); + const std::vector gate_shape{input_size + hidden_size, hidden_size * 2}; + const std::vector cand_shape{hidden_size * 2, hidden_size}; + auto gate_weight_value = GetDefaultParamValue(gate_weight); + if (gate_weight_value == nullptr) { + return RET_ERROR; + } + auto gate_weight_data = reinterpret_cast(gate_weight_value->tensor_addr()); + if (gate_weight_data == nullptr) { + return RET_ERROR; + } + auto gate_weight_shape = gate_weight_value->tensor_shape(); + + auto cand_weight_value = GetDefaultParamValue(cand_weight); + if (cand_weight_value == nullptr) { + return RET_ERROR; + } + auto cand_weight_data = reinterpret_cast(cand_weight_value->tensor_addr()); + if (cand_weight_data == nullptr) { + return RET_ERROR; + } + auto cand_weight_shape = cand_weight_value->tensor_shape(); + + if (gate_weight_shape != gate_shape || cand_weight_shape != cand_shape) { + return RET_ERROR; + } + + // input_update_weight + CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, 0, input_size, hidden_size, + hidden_size * 2, gate_tensor_data, true); + // input_reset_weight + CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, 0, input_size, 0, hidden_size, + gate_tensor_data + input_size * hidden_size, true); + // input_hidden_weight + CopyFlattenMatData(cand_weight_data, input_size + hidden_size, hidden_size, 0, input_size, 0, hidden_size, + 
gate_tensor_data + input_size * hidden_size * 2, true); + + // state_update_weight + CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, input_size, input_size + hidden_size, + hidden_size, hidden_size * 2, recu_tensor_data, true); + // state_reset_weight + CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, input_size, input_size + hidden_size, + 0, hidden_size, recu_tensor_data + hidden_size * hidden_size, true); + // state_hidden_weight + CopyFlattenMatData(cand_weight_data, input_size + hidden_size, hidden_size, input_size, input_size + hidden_size, 0, + hidden_size, recu_tensor_data + hidden_size * hidden_size * 2, true); + return RET_OK; +} + +STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, + const int hidden_size, float *tensor_data) const { + MS_ASSERT(bias != nullptr); + MS_ASSERT(tensor_data != nullptr); + std::vector gate_shape{hidden_size * 2}; + std::vector cand_shape{hidden_size}; + auto gate_bias_value = GetDefaultParamValue(gate_bias); + if (gate_bias_value == nullptr) { + return RET_ERROR; + } + auto gate_bias_data = reinterpret_cast(gate_bias_value->tensor_addr()); + auto gate_bias_shape = gate_bias_value->tensor_shape(); + auto cand_bias_value = GetDefaultParamValue(cand_bias); + if (cand_bias_value == nullptr) { + return RET_ERROR; + } + auto cand_bias_data = reinterpret_cast(cand_bias_value->tensor_addr()); + auto cand_bias_shape = cand_bias_value->tensor_shape(); + if (gate_bias_shape != gate_shape || cand_bias_shape != cand_shape) { + return RET_ERROR; + } + + // update_gate bias + CopyFlattenMatData(gate_bias_data, 1, hidden_size * 2, 0, 1, hidden_size, hidden_size * 2, tensor_data, false); + // reset_gate bias + CopyFlattenMatData(gate_bias_data, 1, hidden_size * 2, 0, 1, 0, hidden_size, tensor_data + hidden_size, false); + // hidden_gate bias + CopyFlattenMatData(cand_bias_data, 1, hidden_size, 0, 1, 0, hidden_size, 
tensor_data + hidden_size * 2, false); + + return RET_OK; +} + +CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, + const AnfNodePtr &hidden_state, + const std::string base_name) const { + MS_ASSERT(func_graph); + MS_ASSERT(hidden_state); + auto stack_primitive = std::make_unique(); + std::unique_ptr attr = std::make_unique(); + attr->axis = 0; + stack_primitive->value.type = schema::PrimitiveType_Stack; + stack_primitive->value.value = attr.release(); + auto stack_cvalue = lite::PrimitiveC::Create(stack_primitive.release()); + auto value_node = NewValueNode(std::shared_ptr(stack_cvalue)); + std::vector new_node_inputs = {value_node, hidden_state, hidden_state}; + auto new_node = func_graph->NewCNode(new_node_inputs); + new_node->set_abstract(hidden_state->abstract()->Clone()); + new_node->set_fullname_with_scope("stack_hidden_" + base_name); + return new_node; +} + +CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, + const EquivPtr &equiv, const EquivPtr &fw_body_equiv, + const EquivPtr &bw_body_equiv, + const std::string &base_name) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(input != nullptr); + MS_ASSERT(equiv != nullptr); + MS_ASSERT(fw_body_equiv != nullptr); + MS_ASSERT(bw_body_equiv != nullptr); + auto gru_primitive = std::make_unique(); + std::unique_ptr attr = std::make_unique(); + attr->bidirection = true; + gru_primitive->value.type = schema::PrimitiveType_Gru; + gru_primitive->value.value = attr.release(); + auto gru_cvalue = lite::PrimitiveC::Create(gru_primitive.release()); + auto value_node = NewValueNode(std::shared_ptr(gru_cvalue)); + + auto fw_gate_kernel = utils::cast((*equiv)[fw_vars_[2]]); + MS_ASSERT(fw_gate_kernel); + auto fw_gate_bias = utils::cast((*equiv)[fw_vars_[3]]); + MS_ASSERT(fw_gate_bias); + auto fw_cand_kernel = utils::cast((*equiv)[fw_vars_[4]]); + MS_ASSERT(fw_cand_kernel); + auto fw_cand_bias = 
utils::cast((*equiv)[fw_vars_[5]]); + MS_ASSERT(fw_cand_bias); + + auto bw_gate_kernel = utils::cast((*equiv)[bw_vars_[2]]); + MS_ASSERT(bw_gate_kernel); + auto bw_gate_bias = utils::cast((*equiv)[bw_vars_[3]]); + MS_ASSERT(bw_gate_bias); + auto bw_cand_kernel = utils::cast((*equiv)[bw_vars_[4]]); + MS_ASSERT(bw_cand_kernel); + auto bw_cand_bias = utils::cast((*equiv)[bw_vars_[5]]); + MS_ASSERT(bw_cand_bias); + + auto hidden = utils::cast((*equiv)[common_vars_[1]]); + MS_ASSERT(hidden); + auto stacked_hidden = GetStackedHiddenState(func_graph, hidden, base_name); + if (stacked_hidden == nullptr) { + return nullptr; + } + auto input_length = utils::cast((*equiv)[input_length_]); + MS_ASSERT(hidden); + + int input_size = 0; + int hidden_size = 0; + auto status = GetInputAndHiddenSize(fw_cand_kernel, bw_cand_kernel, &input_size, &hidden_size); + if (status != RET_OK) { + return nullptr; + } + std::vector gate_weight_shape{2, hidden_size * 3, input_size}; + float *gate_tensor_data = nullptr; + auto gate_weight = AddDefaultParameter(func_graph, base_name + "_gate_weight", gate_weight_shape, kNumberTypeFloat32, + reinterpret_cast(&gate_tensor_data)); + if (gate_weight == nullptr) { + return nullptr; + } + std::vector recu_weight_shape{2, hidden_size * 3, hidden_size}; + float *recu_tensor_data = nullptr; + auto recu_weight = AddDefaultParameter(func_graph, base_name + "_cand_weight", recu_weight_shape, kNumberTypeFloat32, + reinterpret_cast(&recu_tensor_data)); + if (recu_weight == nullptr) { + return nullptr; + } + std::vector bias_shape{2, hidden_size * 6}; + float *bias_tensor_data = nullptr; + auto bias = AddDefaultParameter(func_graph, base_name + "_bias", bias_shape, kNumberTypeFloat32, + reinterpret_cast(&bias_tensor_data)); + if (bias == nullptr) { + return nullptr; + } + for (int i = 0; i < 2 * hidden_size * 6; ++i) { + bias_tensor_data[i] = 0.0f; + } + + if (ConvertWeightData(fw_gate_kernel, fw_cand_kernel, input_size, hidden_size, gate_tensor_data, 
recu_tensor_data) != + RET_OK) { + return nullptr; + } + auto gate_data_diff = hidden_size * input_size * 3; + auto recu_data_diff = hidden_size * hidden_size * 3; + if (ConvertWeightData(bw_gate_kernel, bw_cand_kernel, input_size, hidden_size, gate_tensor_data + gate_data_diff, + recu_tensor_data + recu_data_diff) != RET_OK) { + return nullptr; + } + + if (ConvertBiasData(fw_gate_bias, fw_cand_bias, hidden_size, bias_tensor_data) != RET_OK) { + return nullptr; + } + auto bias_data_diff = hidden_size * 6; + if (ConvertBiasData(bw_gate_bias, bw_cand_bias, hidden_size, bias_tensor_data + bias_data_diff) != RET_OK) { + return nullptr; + } + std::vector new_node_inputs = {value_node, input, gate_weight, recu_weight, + bias, stacked_hidden, input_length}; + auto new_node = func_graph->NewCNode(new_node_inputs); + new_node->set_fullname_with_scope(base_name); + return new_node; +} + +CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, + const std::string base_name) const { + MS_ASSERT(func_graph); + MS_ASSERT(gru_output); + auto split_primitive = std::make_unique(); + std::unique_ptr split_attr = std::make_unique(); + split_attr->numberSplit = 2; + split_attr->splitDim = 1; + split_primitive->value.type = schema::PrimitiveType_Split; + split_primitive->value.value = split_attr.release(); + auto split_cvalue = lite::PrimitiveC::Create(split_primitive.release()); + auto split_value_node = NewValueNode(std::shared_ptr(split_cvalue)); + std::vector new_node_inputs = {split_value_node, gru_output}; + auto split_new_node = func_graph->NewCNode(new_node_inputs); + split_new_node->set_fullname_with_scope("split_" + base_name); + if (TfliteLstmCellFusion::SetAbstractTuple(split_new_node, 2) != RET_OK) { + return nullptr; + } + + auto split_out1 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 0); + if (split_out1 == nullptr) { + return nullptr; + } + auto split_out2 = 
TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 1); + if (split_out2 == nullptr) { + return nullptr; + } + + auto concat_primitive = std::make_unique(); + std::unique_ptr concat_attr = std::make_unique(); + concat_attr->axis = 3; + concat_primitive->value.type = schema::PrimitiveType_Concat; + concat_primitive->value.value = concat_attr.release(); + auto concat_cvalue = lite::PrimitiveC::Create(concat_primitive.release()); + auto concat_value_node = NewValueNode(std::shared_ptr(concat_cvalue)); + std::vector concat_new_node_inputs = {concat_value_node, split_out1, split_out2}; + auto concat_new_node = func_graph->NewCNode(concat_new_node_inputs); + concat_new_node->set_fullname_with_scope("concat_" + base_name); + concat_new_node->set_abstract(gru_output->abstract()->Clone()); + + auto squeeze_primitive = std::make_unique(); + std::unique_ptr squeeze_attr = std::make_unique(); + squeeze_attr->axis = std::vector{1}; + squeeze_primitive->value.type = schema::PrimitiveType_Squeeze; + squeeze_primitive->value.value = squeeze_attr.release(); + auto squeeze_cvalue = lite::PrimitiveC::Create(squeeze_primitive.release()); + auto squeeze_value_node = NewValueNode(std::shared_ptr(squeeze_cvalue)); + std::vector squeeze_new_node_inputs = {squeeze_value_node, concat_new_node}; + auto squeeze_new_node = func_graph->NewCNode(squeeze_new_node_inputs); + squeeze_new_node->set_fullname_with_scope("squeeze_" + base_name); + squeeze_new_node->set_abstract(gru_output->abstract()->Clone()); + + auto transpose_primitive = std::make_unique(); + std::unique_ptr transpose_attr = std::make_unique(); + transpose_attr->perm = std::vector{1, 0, 2}; + transpose_primitive->value.type = schema::PrimitiveType_Transpose; + transpose_primitive->value.value = transpose_attr.release(); + auto transpose_cvalue = lite::PrimitiveC::Create(transpose_primitive.release()); + auto transpose_value_node = NewValueNode(std::shared_ptr(transpose_cvalue)); + std::vector 
transpose_new_node_inputs = {transpose_value_node, squeeze_new_node}; + auto transpose_new_node = func_graph->NewCNode(transpose_new_node_inputs); + transpose_new_node->set_fullname_with_scope("transpose_" + base_name); + transpose_new_node->set_abstract(gru_output->abstract()->Clone()); + + return transpose_new_node; +} + +const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node, + const EquivPtr &equiv) const { + MS_ASSERT(func_graph); + MS_ASSERT(concat_node); + MS_LOG(DEBUG) << "bidirection tf gru fusion pass"; + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return nullptr; + } + + auto transpose_input = utils::cast((*equiv)[transpose_input_]); + MS_ASSERT(transpose_input); + if (!utils::isa(transpose_input) || GetCNodeType(transpose_input) != schema::PrimitiveType_Transpose) { + return nullptr; + } + + PrimitiveVarMapPtr fw_cond_primitive_vars = std::make_shared(); + auto fw_cond_graph_pattern = GetCondGraphPattern(fw_cond_primitive_vars); + auto fw_cond = utils::cast((*equiv)[fw_vars_[0]]); + MS_ASSERT(fw_cond != nullptr); + auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_cond_graph_pattern, fw_cond_primitive_vars, + fw_cond, kCondCNodesNum, kCondNodesNum); + if (fw_cond_equiv == nullptr || fw_cond_equiv->empty()) { + return nullptr; + } + + PrimitiveVarMapPtr bw_cond_primitive_vars = std::make_shared(); + auto bw_cond_graph_pattern = GetCondGraphPattern(bw_cond_primitive_vars); + auto bw_cond = utils::cast((*equiv)[bw_vars_[0]]); + MS_ASSERT(bw_cond != nullptr); + auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_cond_graph_pattern, bw_cond_primitive_vars, + bw_cond, kCondCNodesNum, kCondNodesNum); + if (bw_cond_equiv == nullptr || bw_cond_equiv->empty()) { + return nullptr; + } + + PrimitiveVarMapPtr 
fw_primitive_vars_body = std::make_shared(); + auto fw_body_graph_pattern = GetBodyGraphPattern(fw_primitive_vars_body); + auto fw_body = utils::cast((*equiv)[fw_vars_[1]]); + MS_ASSERT(fw_body != nullptr); + auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_body_graph_pattern, fw_primitive_vars_body, + fw_body, kBodyCNodesNum, kBodyNodesNum); + if (fw_body_equiv == nullptr || fw_body_equiv->empty()) { + return nullptr; + } + + PrimitiveVarMapPtr bw_primitive_vars_body = std::make_shared(); + auto bw_body_graph_pattern = GetBodyGraphPattern(bw_primitive_vars_body); + auto bw_body = utils::cast((*equiv)[bw_vars_[1]]); + MS_ASSERT(bw_body != nullptr); + auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_body_graph_pattern, bw_primitive_vars_body, + bw_body, kBodyCNodesNum, kBodyNodesNum); + if (bw_body_equiv == nullptr || bw_body_equiv->empty()) { + return nullptr; + } + + const std::string gru_name = "gru_" + concat_node->fullname_with_scope(); + auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, fw_body_equiv, bw_body_equiv, gru_name); + if (gru_node == nullptr) { + return nullptr; + } + if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) { + return nullptr; + } + + auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0); + if (get_item_node == nullptr) { + return nullptr; + } + + auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope()); + MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success"; + return output_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h new file mode 100644 index 0000000000..be5ef15190 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h @@ -0,0 +1,73 @@ +/** + * 
Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ +#include +#include +#include +#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" +#include "schema/inner/model_generated.h" +#include "src/param_value_lite.h" +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace opt { +class BiDirectionTfGruCellFusion : public PatternProcessPass { + public: + explicit BiDirectionTfGruCellFusion(const std::string &name = "bidirection_tf_gru_cell_fusion", + bool multigraph = true); + ~BiDirectionTfGruCellFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + protected: + virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; + + private: + AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; + CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv, + const EquivPtr &fw_body_equiv, const EquivPtr &bw_body_equiv, + const std::string &base_name) const; + ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const; + 
lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf, + int *input_size, int *hidden_size) const; + ParameterPtr AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name, + const std::vector &shape, const TypeId type, void **tensor_data) const; + lite::STATUS ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, const int input_size, + const int hidden_size, float *gate_tensor_data, float *recu_tensor_data) const; + lite::STATUS ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, const int hidden_size, + float *tensor_data) const; + void CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, const int c0, + const int c1, float *data, bool t = false) const; + CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &hidden_state, + const std::string base_name) const; + CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, + const std::string base_name) const; + + private: + std::vector common_vars_; + std::vector fw_vars_; + std::vector bw_vars_; + VarPtr input_; + VarPtr input_length_; + VarPtr transpose_input_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc new file mode 100644 index 0000000000..4171475e0f --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc @@ -0,0 +1,370 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h" +#include +#include "src/ops/primitive_c.h" +#include "src/common/utils.h" +#include "src/param_value_lite.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "securec/include/securec.h" +#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kLstmInputsLength = 13; +constexpr size_t kLstmInputsVarNum = 11; +constexpr size_t kCondNodesNum = 12; +constexpr size_t kCondCNodesNum = 4; +constexpr size_t kBodyNodesNum = 82; +constexpr size_t kBodyCNodesNum = 30; +const auto &p1 = std::placeholders::_1; + +bool IsParameterNode(const BaseRef &n) { return utils::isa(n); } + +bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) { + if (utils::isa(n) || utils::isa(n)) { + return opt::GetCNodeType(n) == type; + } + return false; +} +} // namespace + +TfLstmCellFusion::TfLstmCellFusion(const std::string &name, bool multigraph) + : TfliteLstmCellFusion(name, multigraph, kLstmInputsLength, kLstmInputsVarNum, kCondNodesNum, kCondCNodesNum, + kBodyNodesNum, kBodyCNodesNum) { + /* + * vars for lstm cell input + * 0:cond 1:body 2:index 3:limit1 4:output 5:cell 6:hidden 7:limit2 8:input 9:kernel 10:bias + */ + forget_bias_ = std::make_shared(); +} + +AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { + std::vector placeholders; + for (int i = 0; i < 10; ++i) { + placeholders.emplace_back(std::make_shared(IsParameterNode)); + } + 
VectorRef add2 = VectorRef({std::make_shared(), placeholders[2], std::make_shared(IsParameterNode)}); + VectorRef add3 = VectorRef({std::make_shared(), placeholders[0], std::make_shared(IsParameterNode)}); + + VectorRef get_item = VectorRef( + {std::make_shared("GetItem"), placeholders[7], placeholders[2], std::make_shared(IsParameterNode)}); + VectorRef concat_input_h = VectorRef({std::make_shared(), get_item, placeholders[5]}); + + VectorRef matmul = VectorRef({std::make_shared(), concat_input_h, placeholders[8]}); + VectorRef bias = VectorRef({std::make_shared(), matmul, placeholders[9]}); + VectorRef split = VectorRef({std::make_shared(), bias}); + + VectorRef get_item1 = VectorRef({std::make_shared(), split, std::make_shared()}); + VectorRef get_item2 = VectorRef({std::make_shared(), split, std::make_shared()}); + VectorRef get_item3 = VectorRef({std::make_shared(), split, std::make_shared()}); + VectorRef get_item4 = VectorRef({std::make_shared(), split, std::make_shared()}); + + VectorRef input_gate = VectorRef({std::make_shared("Sigmoid"), get_item1}); + VectorRef input_to_cell = VectorRef({std::make_shared("Tanh"), get_item2}); + VectorRef forget_bias = VectorRef({std::make_shared("Add"), get_item3, forget_bias_}); + VectorRef forget_gate = VectorRef({std::make_shared("Sigmoid"), forget_bias}); + VectorRef output_gate = VectorRef({std::make_shared("Sigmoid"), get_item4}); + + VectorRef forgetted_cell = VectorRef({std::make_shared(""), forget_gate, placeholders[4]}); + VectorRef inputted_cell = VectorRef({std::make_shared(""), input_gate, input_to_cell}); + VectorRef input_forget_cell = VectorRef({std::make_shared("Add"), forgetted_cell, inputted_cell}); + VectorRef to_new_hidden = VectorRef({std::make_shared("Tanh"), input_forget_cell}); + VectorRef new_hidden = VectorRef({std::make_shared("Mul"), output_gate, to_new_hidden}); + + VectorRef new_to_cell = VectorRef({std::make_shared("Mul"), cell_smooth_new_, input_forget_cell}); + VectorRef old_to_cell = 
VectorRef({std::make_shared("Mul"), cell_smooth_old_, placeholders[4]}); + VectorRef output_cell = VectorRef({std::make_shared("Add"), new_to_cell, old_to_cell}); + + VectorRef new_to_hidden = VectorRef({std::make_shared("Mul"), hidden_smooth_new_, new_hidden}); + VectorRef old_to_hidden = VectorRef({std::make_shared("Mul"), hidden_smooth_old_, placeholders[5]}); + VectorRef output_hidden = VectorRef({std::make_shared("Add"), new_to_hidden, old_to_hidden}); + + VectorRef set_item = VectorRef({std::make_shared(""), placeholders[3], placeholders[2], new_hidden}); + + auto is_make_tuple = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple)); + std::vector outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden}; + outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end()); + VectorRef make_tuple_node = VectorRef(outputs); + auto is_return = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); + VectorRef return_node = VectorRef({is_return, make_tuple_node}); + + VarPtr fg = std::make_shared("RootG"); + auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true); + return pattern; +} + +STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector &shape, + const float *const data_ptr, const int hidden_size) const { + MS_ASSERT(weight != nullptr); + MS_ASSERT(data_ptr != nullptr); + auto default_param = std::make_shared(); + if (default_param == nullptr) { + MS_LOG(ERROR) << "new_default is nullptr"; + return RET_ERROR; + } + default_param->set_tensor_shape(shape); + default_param->set_tensor_type(kNumberTypeFloat32); + default_param->set_format(schema::Format_NHWC); + + if (shape.size() != 3) { + MS_LOG(ERROR) << "lstm weight shape must have 3 dims"; + return RET_ERROR; + } + const auto param_num = shape[0] * shape[1] * shape[2]; + auto tensor_data = new (std::nothrow) float[param_num * 4]; + std::vector data_diff{0, 3, 2, 1}; 
+ if (tensor_data == nullptr) { + MS_LOG(DEBUG) << "new data failed"; + return RET_ERROR; + } + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < hidden_size; ++j) { + for (int t = 0; t < shape[2]; ++t) { + tensor_data[(i * hidden_size + j) * shape[2] + t] = data_ptr[t * shape[1] + data_diff[i] * hidden_size + j]; + } + } + } + default_param->SetTensorData(tensor_data, param_num * 4); + weight->set_default_param(default_param); + std::vector shape_vector_i(shape.begin(), shape.end()); + auto abstract_tensor_i = std::make_shared(kFloat32, shape_vector_i); + if (abstract_tensor_i == nullptr) { + MS_LOG(ERROR) << "abstract_tensor is nullptr"; + delete[] tensor_data; + return RET_ERROR; + } + weight->set_abstract(abstract_tensor_i); + return RET_OK; +} + +STATUS TfLstmCellFusion::SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i, + const ParameterPtr &weight_c, int hidden_size) const { + // split input_size and hidden_size at dim 0 + // transform i,c,f,o to i,o,f,c at dim 1 + MS_ASSERT(weight != nullptr); + MS_ASSERT(wiehgt_i != nullptr); + MS_ASSERT(wiehgt_c != nullptr); + if (!utils::isa(weight)) { + return RET_ERROR; + } + auto weight_param = utils::cast(weight); + if (!weight_param->has_default()) { + MS_LOG(DEBUG) << "weight not have default value"; + return RET_ERROR; + } + if (!utils::isa(weight_param->default_param())) { + MS_LOG(DEBUG) << "default value is not ParamValueLite"; + return RET_FAILED; + } + auto origin_tensor = std::dynamic_pointer_cast(weight_param->default_param()); + if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { + MS_LOG(DEBUG) << "origin_tensor is not float32 type"; + return RET_ERROR; + } + auto data_ptr = reinterpret_cast(origin_tensor->tensor_addr()); + auto data_shape = origin_tensor->tensor_shape(); + if (data_shape.size() != 2) { + MS_LOG(ERROR) << "weight data shape invalid"; + return RET_ERROR; + } + if (data_shape[0] <= hidden_size) { + MS_LOG(ERROR) 
<< "weight data shape[0] invalid"; + return RET_ERROR; + } + if (hidden_size * 4 != data_shape[1]) { + MS_LOG(ERROR) << "weight data shape[1] invalid"; + return RET_ERROR; + } + const auto input_size = data_shape[0] - hidden_size; + + std::vector shape_i{1, 4 * hidden_size, input_size}; + if (SetWeightAbstractAndDefault(weight_i, shape_i, data_ptr, hidden_size) != RET_OK) { + MS_LOG(ERROR) << "get weight_i failed"; + return RET_ERROR; + } + + std::vector shape_c{1, 4 * hidden_size, hidden_size}; + if (SetWeightAbstractAndDefault(weight_c, shape_c, data_ptr + input_size * data_shape[1], hidden_size) != RET_OK) { + MS_LOG(ERROR) << "get weight_i failed"; + return RET_ERROR; + } + return RET_OK; +} + +STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias, + const AnfNodePtr &old_bias, const int hidden_size) const { + MS_ASSERT(body_equiv != nullptr); + MS_ASSERT(new_bias != nullptr); + MS_ASSERT(old_bias != nullptr); + if (!utils::isa(old_bias)) { + MS_LOG(DEBUG) << "old_bias is not parameter"; + return RET_ERROR; + } + auto old_bias_param = utils::cast(old_bias); + if (!old_bias_param->has_default()) { + MS_LOG(DEBUG) << "bias not have default value"; + return RET_ERROR; + } + if (!utils::isa(old_bias_param->default_param())) { + MS_LOG(DEBUG) << "default value is not ParamValueLite"; + return RET_FAILED; + } + auto origin_tensor = std::dynamic_pointer_cast(old_bias_param->default_param()); + if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { + MS_LOG(DEBUG) << "origin_tensor is not float32 type"; + return RET_ERROR; + } + auto data_ptr = reinterpret_cast(origin_tensor->tensor_addr()); + auto data_shape = origin_tensor->tensor_shape(); + if (data_shape.size() != 1 || data_shape[0] != 4 * hidden_size) { + MS_LOG(DEBUG) << "bias data shape illegal"; + return RET_ERROR; + } + std::vector shape{1, 8 * hidden_size}; + + auto default_param = std::make_shared(); + if 
(default_param == nullptr) { + MS_LOG(ERROR) << "new_default is nullptr"; + return RET_ERROR; + } + default_param->set_tensor_shape(shape); + default_param->set_tensor_type(kNumberTypeFloat32); + default_param->set_format(schema::Format_NHWC); + auto tensor_data = new (std::nothrow) float[hidden_size * 8]; + + auto forget_bias_node = utils::cast((*body_equiv)[forget_bias_]); + if (forget_bias_node == nullptr) { + MS_LOG(ERROR) << "forget bias node is nullptr"; + return RET_ERROR; + } + float forget_bias_value = 0.0f; + if (GetFloatScalarFromParamValueLite(forget_bias_node, &forget_bias_value) != RET_OK) { + return RET_ERROR; + } + + std::vector data_diff{0, 3, 2, 1}; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < hidden_size; ++j) { + if (i < 4) { + tensor_data[i * hidden_size + j] = data_ptr[data_diff[i] * hidden_size + j]; + if (i == 2) { // forget bias + tensor_data[i * hidden_size + j] += forget_bias_value; + } + } else { + tensor_data[i * hidden_size + j] = 0.0f; + } + } + } + default_param->SetTensorData(tensor_data, hidden_size * 8 * 4); + new_bias->set_default_param(default_param); + std::vector shape_vector_i(shape.begin(), shape.end()); + auto abstract_tensor_i = std::make_shared(kFloat32, shape_vector_i); + if (abstract_tensor_i == nullptr) { + MS_LOG(ERROR) << "abstract_tensor is nullptr"; + delete[] tensor_data; + return RET_ERROR; + } + new_bias->set_abstract(abstract_tensor_i); + return RET_OK; +} + +CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, + const EquivPtr &body_equiv, const std::string &base_name, + const float smooth) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(equiv != nullptr); + auto lstm_primitive = std::make_unique(); + std::unique_ptr attr = std::make_unique(); + attr->bidirection = false; + attr->smooth = smooth; + lstm_primitive->value.type = schema::PrimitiveType_Lstm; + lstm_primitive->value.value = attr.release(); + auto lstm_cvalue = 
lite::PrimitiveC::Create(lstm_primitive.release()); + auto value_node = NewValueNode(std::shared_ptr(lstm_cvalue)); + + auto &vars = while_input_vars_; + + auto limit1 = utils::cast((*equiv)[vars[3]]); + MS_ASSERT(limit1); + auto limit2 = utils::cast((*equiv)[vars[7]]); + MS_ASSERT(limit2); + auto weight = utils::cast((*equiv)[vars[9]]); + MS_ASSERT(weight); + auto bias = utils::cast((*equiv)[vars[10]]); + MS_ASSERT(bias); + auto input = utils::cast((*equiv)[vars[8]]); + MS_ASSERT(input); + auto cell = utils::cast((*equiv)[vars[5]]); + MS_ASSERT(cell); + auto hidden = utils::cast((*equiv)[vars[6]]); + MS_ASSERT(hidden); + + if (!utils::isa(hidden)) { + MS_LOG(DEBUG) << "hidden is not parameter"; + return nullptr; + } + auto hidden_param = utils::cast(hidden); + if (!utils::isa(hidden_param->abstract())) { + MS_LOG(DEBUG) << "hidden abstract is not AbstractTensor"; + return nullptr; + } + auto abstract_tensor = utils::cast(hidden_param->abstract()); + auto hidden_shape = utils::cast(abstract_tensor->BuildShape())->shape(); + if (hidden_shape.size() == 0) { + MS_LOG(DEBUG) << "can't get hidden shape"; + return nullptr; + } + + auto i_weight = func_graph->add_parameter(); + i_weight->set_name(base_name + "_weight_i"); + i_weight->set_abstract(weight->abstract()->Clone()); + + auto c_weight = func_graph->add_parameter(); + c_weight->set_name(base_name + "_weight_c"); + c_weight->set_abstract(weight->abstract()->Clone()); + + if (SplitWeights(weight, i_weight, c_weight, hidden_shape.back()) != RET_OK) { + MS_LOG(DEBUG) << "split weight to i_weight and c_weight failed"; + return nullptr; + } + + auto bias_node = func_graph->add_parameter(); + bias_node->set_name(base_name + "_bias"); + bias_node->set_abstract(bias->abstract()->Clone()); + + if (PopulateBiasNode(body_equiv, bias_node, bias, hidden_shape.back()) != RET_OK) { + MS_LOG(DEBUG) << "reorder bias failed"; + return nullptr; + } + + if (!utils::isa(input) || GetCNodeType(input) != 
schema::PrimitiveType_TensorListFromTensor) { + MS_LOG(DEBUG) << "input is not tensorlistfromtensor op"; + return nullptr; + } + auto tensor_list_cnode = utils::cast(input); + auto input_tensor_node = tensor_list_cnode->input(1); + + std::vector new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias_node, hidden, + cell}; + auto new_node = func_graph->NewCNode(new_node_inputs); + new_node->set_fullname_with_scope(base_name); + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h new file mode 100644 index 0000000000..ce8628e3f3 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" +#include "src/param_value_lite.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace opt { +class TfLstmCellFusion : public TfliteLstmCellFusion { + public: + explicit TfLstmCellFusion(const std::string &name = "lstm_cell_fusion", bool multigraph = true); + ~TfLstmCellFusion() override = default; + + private: + AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const override; + CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv, + const std::string &base_name, const float smooth) const override; + + lite::STATUS SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i, const ParameterPtr &weight_c, + int hidden_size) const; + lite::STATUS SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector &shape, + const float *const data_ptr, const int hidden_size) const; + lite::STATUS PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias, const AnfNodePtr &old_bias, + const int hidden_size) const; + + private: + VarPtr forget_bias_ = nullptr; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc new file mode 100644 index 0000000000..4812ed15d0 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc @@ -0,0 +1,727 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * 
you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" +#include +#include +#include "src/ops/primitive_c.h" +#include "src/common/utils.h" +#include "src/param_value_lite.h" +#include "schema/inner/model_generated.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kWhileInputsLength = 23; +constexpr size_t kWhileInputsVarNum = 21; +constexpr size_t kCondNodesNum = 12; +constexpr size_t kCondCNodesNum = 4; +constexpr size_t kBodyNodesNum = 95; +constexpr size_t kBodyCNodesNum = 34; +constexpr size_t kLSTMOutputNum = 3; +const auto &p1 = std::placeholders::_1; +constexpr float EPSILON = 1e-5; + +bool IsParameterNode(const BaseRef &n) { return utils::isa(n); } + +bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) { + if (utils::isa(n) || utils::isa(n)) { + return opt::GetCNodeType(n) == type; + } + return false; +} +} // namespace + +STATUS TfliteLstmCellFusion::GetFloatScalarFromParamValueLite(const AnfNodePtr ¶m_value, float *v) const { + if (param_value == nullptr || v == nullptr) { + MS_LOG(ERROR) << "param_value or v is nullptr"; + return RET_ERROR; + } + if (!utils::isa(param_value)) { + MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr"; + return RET_ERROR; + } + auto param_ptr = utils::cast(param_value); + if (!param_ptr->has_default()) { + MS_LOG(DEBUG) << "param not have default"; + return RET_ERROR; + } + auto 
default_param = param_ptr->default_param(); + if (!utils::isa(default_param)) { + MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr"; + return RET_ERROR; + } + auto default_param_ptr = utils::cast(default_param); + auto tensor_shape = default_param_ptr->tensor_shape(); + if (!(tensor_shape.size() == 0 || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) { + MS_LOG(DEBUG) << "default param is not scalar"; + return RET_ERROR; + } + if (default_param_ptr->tensor_type() != kNumberTypeFloat32 && default_param_ptr->tensor_type() != kNumberTypeFloat) { + MS_LOG(DEBUG) << "default param is not float"; + return RET_ERROR; + } + *v = *(reinterpret_cast(default_param_ptr->tensor_addr())); + return RET_OK; +} + +TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num, + int cond_nodes_num, int cond_cnodes_num, int body_nodes_num, + int body_cnodes_num) + : PatternProcessPass(name, multigraph) { + /* + * input vars for lstm while node + * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_ + * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_ + */ + this->while_inputs_num_ = input_length == 0 ? kWhileInputsLength : input_length; + this->while_input_var_num_ = var_num == 0 ? kWhileInputsVarNum : var_num; + this->cond_nodes_num_ = cond_nodes_num == 0 ? kCondNodesNum : cond_nodes_num; + this->cond_cnodes_num_ = cond_cnodes_num == 0 ? kCondCNodesNum : cond_cnodes_num; + this->body_nodes_num_ = body_nodes_num == 0 ? kBodyNodesNum : body_nodes_num; + this->body_cnodes_num_ = body_cnodes_num == 0 ? 
kBodyCNodesNum : body_cnodes_num; + for (size_t i = 0; i < this->while_input_var_num_; ++i) { + while_input_vars_.emplace_back(std::make_shared()); + } + cell_smooth_old_ = std::make_shared(); + cell_smooth_new_ = std::make_shared(); + hidden_smooth_old_ = std::make_shared(); + hidden_smooth_new_ = std::make_shared(); +} + +AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { + auto is_parameter1 = std::make_shared(IsParameterNode); + auto is_parameter2 = std::make_shared(IsParameterNode); + auto is_parameter3 = std::make_shared(IsParameterNode); + auto is_parameter4 = std::make_shared(IsParameterNode); + auto is_less1 = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); + auto is_less2 = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); + auto is_logical_and = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_LogicalAnd)); + auto is_return = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); + VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2}); + VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4}); + VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref}); + VectorRef return_ref = VectorRef({is_return, logicaland_ref}); + VarPtr fg = std::make_shared("RootG"); + auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true); + return pattern; +} + +AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { + std::vector placeholders; + for (int i = 0; i < 20; ++i) { + placeholders.emplace_back(std::make_shared(IsParameterNode)); + } + VectorRef add2 = VectorRef({std::make_shared(), placeholders[2], std::make_shared(IsParameterNode)}); + VectorRef add3 = VectorRef({std::make_shared(), placeholders[0], std::make_shared(IsParameterNode)}); + + VectorRef concat_i_w = VectorRef({std::make_shared(), placeholders[8], 
placeholders[12]}); + VectorRef concat_f_w = VectorRef({std::make_shared(), placeholders[9], placeholders[13]}); + VectorRef concat_c_w = VectorRef({std::make_shared(), placeholders[10], placeholders[14]}); + VectorRef concat_o_w = VectorRef({std::make_shared(), placeholders[11], placeholders[15]}); + + VectorRef get_item = VectorRef( + {std::make_shared("GetItem"), placeholders[7], placeholders[2], std::make_shared(IsParameterNode)}); + VectorRef concat_input_h = VectorRef({std::make_shared(), get_item, placeholders[5]}); + + VectorRef matmul_input = VectorRef({std::make_shared(), concat_input_h, concat_i_w}); + VectorRef matmul_forget = VectorRef({std::make_shared(), concat_input_h, concat_f_w}); + VectorRef matmul_cell = VectorRef({std::make_shared(), concat_input_h, concat_c_w}); + VectorRef matmul_output = VectorRef({std::make_shared(), concat_input_h, concat_o_w}); + + VectorRef bias_input = VectorRef({std::make_shared(), matmul_input, placeholders[16]}); + VectorRef bias_forget = VectorRef({std::make_shared(), matmul_forget, placeholders[17]}); + VectorRef bias_cell = VectorRef({std::make_shared(), matmul_cell, placeholders[18]}); + VectorRef bias_output = VectorRef({std::make_shared(), matmul_output, placeholders[19]}); + + VectorRef cell = VectorRef({std::make_shared("Tanh"), bias_cell}); + VectorRef input_gate = VectorRef({std::make_shared("Sigmoid"), bias_input}); + VectorRef cell_input = VectorRef({std::make_shared("Mul"), input_gate, cell}); + VectorRef forget_gate = VectorRef({std::make_shared("Sigmoid"), bias_forget}); + VectorRef cell_forgeted = VectorRef({std::make_shared("Mul"), forget_gate, placeholders[4]}); + VectorRef cell_new = VectorRef({std::make_shared("Add"), cell_forgeted, cell_input}); + + VectorRef smooth_cell_old = VectorRef({std::make_shared("Mul"), cell_smooth_old_, placeholders[4]}); + VectorRef smooth_cell_new = VectorRef({std::make_shared("Mul"), cell_smooth_new_, cell_new}); + VectorRef cell_output = 
VectorRef({std::make_shared("Add"), smooth_cell_new, smooth_cell_old}); + + VectorRef output_gate = VectorRef({std::make_shared("Sigmoid"), bias_output}); + VectorRef cell_to_output = VectorRef({std::make_shared("Tanh"), cell_new}); + VectorRef output = VectorRef({std::make_shared("Mul"), output_gate, cell_to_output}); + + VectorRef smooth_hidden_old = VectorRef({std::make_shared("Mul"), hidden_smooth_old_, placeholders[5]}); + VectorRef smooth_hidden_new = VectorRef({std::make_shared("Mul"), hidden_smooth_new_, output}); + VectorRef hidden_output = VectorRef({std::make_shared("Add"), smooth_hidden_new, smooth_hidden_old}); + + VectorRef set_item = VectorRef({std::make_shared("SetItem"), placeholders[3], placeholders[2], output}); + + auto is_make_tuple = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple)); + std::vector outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output}; + outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end()); + VectorRef make_tuple_node = VectorRef(outputs); + auto is_return = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); + VectorRef return_node = VectorRef({is_return, make_tuple_node}); + + VarPtr fg = std::make_shared("RootG"); + auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true); + return pattern; +} + +const BaseRef TfliteLstmCellFusion::DefinePattern() const { + auto is_while_node = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_While)); + VectorRef while_node = VectorRef({is_while_node}); + auto while_inputs = while_input_vars_; + while_inputs.insert(while_inputs.begin() + 4, while_input_vars_[2]); + while_node.insert(while_node.end(), while_inputs.begin(), while_inputs.end()); + + auto is_tuple_get_item = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)); + VectorRef while_output = VectorRef({is_tuple_get_item, while_node, std::make_shared()}); + + auto 
is_tensor_list_stack = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)); + auto is_parameter = std::make_shared(IsParameterNode); + VectorRef tensor_list_stack_node = VectorRef({is_tensor_list_stack, while_output, is_parameter}); + + return tensor_list_stack_node; +} + +EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars, + const AnfNodePtr &pattern) { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(pattern != nullptr); + auto return_node = func_graph->get_return(); + PatternEngine pattern_engine(PatternEngine(std::make_shared(), + std::function(AnfEqual), + std::function(CNodeTypeEqual))); + auto empty_equiv = std::make_shared(); + EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv); + return equiv; +} + +// make sure that only 3,4,5 output of while is referenced +bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(while_cnode != nullptr); + auto manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(ERROR) << "manager is nullptr"; + return RET_ERROR; + } + auto while_node_users = manager->node_users()[while_cnode]; + std::vector valid_indexes{3, 4, 5}; + for (auto &node_user : while_node_users) { + if (!utils::isa(node_user.first)) { + return false; + } + auto cnode = utils::cast(node_user.first); + if (GetCNodeType(cnode) != schema::PrimitiveType_TupleGetItem) { + return false; + } + auto index = GetTupleGetItemOutIndex(cnode); + if (!lite::IsContain(valid_indexes, index)) { + return false; + } + } + return true; +} + +EquivPtr TfliteLstmCellFusion::CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern, + const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph, + const size_t cnode_num, const size_t all_node_num) { + MS_ASSERT(func_graph != nullptr); + 
MS_ASSERT(pattern != nullptr); + MS_ASSERT(anf_sub_graph != nullptr); + auto sub_graph = GetValueNode(anf_sub_graph); + auto nodes = TopoSort(sub_graph->get_return()); + auto cnodes = sub_graph->GetOrderedCnodes(); + if (cnodes.size() != cnode_num || nodes.size() != all_node_num) { + MS_LOG(DEBUG) << "sub graph nodes num not match"; + return nullptr; + } + return MatchGraph(sub_graph, primitive_vars, pattern); +} + +bool TfliteLstmCellFusion::CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, + const CNodePtr &while_cnode, float *smooth) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(equiv != nullptr); + MS_ASSERT(while_cnode != nullptr); + MS_ASSERT(smooth != nullptr); + + auto cell_smooth_old_node = utils::cast((*equiv)[cell_smooth_old_]); + MS_ASSERT(cell_smooth_old_node != nullptr); + auto cell_smooth_new_node = utils::cast((*equiv)[cell_smooth_new_]); + MS_ASSERT(cell_smooth_new_node != nullptr); + auto hidden_smooth_old_node = utils::cast((*equiv)[hidden_smooth_old_]); + MS_ASSERT(hidden_smooth_old_node != nullptr); + auto hidden_smooth_new_node = utils::cast((*equiv)[hidden_smooth_new_]); + MS_ASSERT(hidden_smooth_new_node != nullptr); + + float cell_old, cell_new, hidden_old, hidden_new; + if (GetFloatScalarFromParamValueLite(cell_smooth_old_node, &cell_old) != RET_OK) { + return false; + } + if (GetFloatScalarFromParamValueLite(cell_smooth_new_node, &cell_new) != RET_OK) { + return false; + } + if (GetFloatScalarFromParamValueLite(hidden_smooth_old_node, &hidden_old) != RET_OK) { + return false; + } + if (GetFloatScalarFromParamValueLite(hidden_smooth_new_node, &hidden_new) != RET_OK) { + return false; + } + if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) { + MS_LOG(DEBUG) << "cell smooth value illegal"; + return false; + } + if (hidden_old < 0.0f || hidden_old > 1.0f || hidden_new < 0.0f || hidden_new > 1.0f) { + MS_LOG(DEBUG) << "hidden smooth value illegal"; + return false; + } + if 
(std::abs(cell_old + cell_new - 1.0f) > EPSILON || std::abs(hidden_old + hidden_new - 1.0f) > EPSILON || + std::abs(cell_old - hidden_old) > EPSILON) { + MS_LOG(DEBUG) << "smooth value illegal"; + return false; + } + *smooth = cell_old; + return true; +} + +STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector ¶ms, const ParameterPtr &new_param, + bool is_bias) const { + MS_ASSERT(new_param != nullptr); + MS_ASSERT(params.size() == 4); + std::vector data_ptrs; + std::vector> data_shapes; + for (auto ¶m : params) { + if (!utils::isa(param)) { + MS_LOG(DEBUG) << "param is not Parameter node"; + return RET_FAILED; + } + auto param_t = utils::cast(param); + if (!param_t->has_default()) { + MS_LOG(DEBUG) << "param not have default value"; + return RET_FAILED; + } + if (!utils::isa(param_t->default_param())) { + MS_LOG(DEBUG) << "default value is not ParamValueLite"; + return RET_FAILED; + } + auto origin_tensor = std::dynamic_pointer_cast(param_t->default_param()); + if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { + MS_LOG(DEBUG) << "origin_tensor is not float32 type"; + return RET_FAILED; + } + auto data_ptr = reinterpret_cast(origin_tensor->tensor_addr()); + auto data_shape = origin_tensor->tensor_shape(); + data_ptrs.push_back(data_ptr); + data_shapes.push_back(data_shape); + } + + for (size_t i = 1; i < data_shapes.size(); ++i) { + if (data_shapes[i] != data_shapes[0]) { + MS_LOG(DEBUG) << "data shape not same"; + return RET_FAILED; + } + } + auto new_default = std::make_shared(); + if (new_default == nullptr) { + MS_LOG(ERROR) << "new_default is nullptr"; + return RET_ERROR; + } + std::vector new_shape; + float *tensor_data = nullptr; + int step = 0; + int data_size = 0; + if (is_bias) { + if (data_shapes[0].size() != 1) { + MS_LOG(ERROR) << "bias data shape error"; + return RET_ERROR; + } + step = data_shapes[0][0]; + data_size = 8 * step; + new_shape = std::vector({1, data_size}); + + } else 
{ + if (data_shapes[0].size() != 2) { + MS_LOG(ERROR) << "weight data shape error"; + return RET_ERROR; + } + new_shape = std::vector({1, data_shapes[0][0] * 4, data_shapes[0][1]}); + step = data_shapes[0][0] * data_shapes[0][1]; + data_size = 4 * step; + } + + tensor_data = new (std::nothrow) float[data_size]; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new data failed"; + return RET_ERROR; + } + for (int i = 0; i < data_size; ++i) { // bias are stored into first 4*hidden_size buffer, the rest is all 0 + tensor_data[i] = 0.0f; + } + + for (size_t i = 0; i < data_ptrs.size(); ++i) { + auto source_len = std::accumulate(data_shapes[i].begin(), data_shapes[i].end(), 1, std::multiplies()); + auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s error"; + delete[] tensor_data; + return RET_ERROR; + } + } + new_default->set_tensor_shape(new_shape); + new_default->set_tensor_type(kNumberTypeFloat32); + new_default->set_format(schema::Format_NHWC); + new_default->SetTensorData(tensor_data, data_size * sizeof(float)); + new_param->set_default_param(new_default); + + std::vector shape_vector(new_shape.begin(), new_shape.end()); + auto abstract_tensor = std::make_shared(kFloat32, shape_vector); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "abstract_tensor is nullptr"; + return RET_ERROR; + } + new_param->set_abstract(abstract_tensor); + return RET_OK; +} + +CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, + const EquivPtr &body_equiv, const std::string &base_name, + const float smooth) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(equiv != nullptr); + MS_ASSERT(body_equiv != nullptr); + /* + * input vars for while node + * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_ + * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 
19:c_bias_ 20:o_bias_ + */ + auto lstm_primitive = std::make_unique(); + std::unique_ptr attr = std::make_unique(); + attr->bidirection = false; + attr->smooth = smooth; + lstm_primitive->value.type = schema::PrimitiveType_Lstm; + lstm_primitive->value.value = attr.release(); + auto lstm_cvalue = lite::PrimitiveC::Create(lstm_primitive.release()); + auto value_node = NewValueNode(std::shared_ptr(lstm_cvalue)); + + auto &vars = while_input_vars_; + + auto limit1 = utils::cast((*equiv)[vars[3]]); + MS_ASSERT(limit1); + auto limit2 = utils::cast((*equiv)[vars[7]]); + MS_ASSERT(limit2); + + auto i2i_weight = utils::cast((*equiv)[vars[9]]); + MS_ASSERT(i2i_weight); + auto i2f_weight = utils::cast((*equiv)[vars[10]]); + MS_ASSERT(i2f_weight); + auto i2c_weight = utils::cast((*equiv)[vars[11]]); + MS_ASSERT(i2c_weight); + auto i2o_weight = utils::cast((*equiv)[vars[12]]); + MS_ASSERT(i2o_weight); + + auto c2i_weight = utils::cast((*equiv)[vars[13]]); + MS_ASSERT(c2i_weight); + auto c2f_weight = utils::cast((*equiv)[vars[14]]); + MS_ASSERT(c2f_weight); + auto c2c_weight = utils::cast((*equiv)[vars[15]]); + MS_ASSERT(c2c_weight); + auto c2o_weight = utils::cast((*equiv)[vars[16]]); + MS_ASSERT(c2o_weight); + + auto i_bias = utils::cast((*equiv)[vars[17]]); + MS_ASSERT(i_bias); + auto f_bias = utils::cast((*equiv)[vars[18]]); + MS_ASSERT(f_bias); + auto c_bias = utils::cast((*equiv)[vars[19]]); + MS_ASSERT(c_bias); + auto o_bias = utils::cast((*equiv)[vars[20]]); + MS_ASSERT(o_bias); + + auto input = utils::cast((*equiv)[vars[8]]); + MS_ASSERT(input); + auto cell = utils::cast((*equiv)[vars[5]]); + MS_ASSERT(cell); + auto hidden = utils::cast((*equiv)[vars[6]]); + MS_ASSERT(hidden); + + std::vector i_weights{i2i_weight, i2o_weight, i2f_weight, i2c_weight}; + auto i_weight = func_graph->add_parameter(); + auto status = GetConcatedParam(i_weights, i_weight, false); + if (status != RET_OK) { + return nullptr; + } + i_weight->set_name(base_name + "_weight_i"); + + std::vector 
c_weights{c2i_weight, c2o_weight, c2f_weight, c2c_weight}; + auto c_weight = func_graph->add_parameter(); + status = GetConcatedParam(c_weights, c_weight, false); + if (status != RET_OK) { + return nullptr; + } + c_weight->set_name(base_name + "_weight_c"); + + std::vector biases{i_bias, o_bias, f_bias, c_bias}; + auto bias = func_graph->add_parameter(); + status = GetConcatedParam(biases, bias, true); + if (status != RET_OK) { + return nullptr; + } + bias->set_name(base_name + "_bias"); + + if (!utils::isa(input) || GetCNodeType(input) != schema::PrimitiveType_TensorListFromTensor) { + MS_LOG(DEBUG) << "input is not tensorlistfromtensor op"; + return nullptr; + } + auto tensor_list_cnode = utils::cast(input); + auto input_tensor_node = tensor_list_cnode->input(1); + + std::vector new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias, hidden, cell}; + auto new_node = func_graph->NewCNode(new_node_inputs); + new_node->set_fullname_with_scope(base_name); + return new_node; +} + +CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, + const int item_index) { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(node != nullptr); + MS_ASSERT(get_items != nullptr); + auto tuple_get_item_prim_ptr = lite::GetTupleGetItemPrim(); + if (tuple_get_item_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; + return nullptr; + } + auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); + auto get_item_value = NewValueNode(MakeValue(item_index)); + if (tuple_get_item_prim == nullptr || get_item_value == nullptr) { + MS_LOG(ERROR) << "NewValueNode is nullptr"; + return nullptr; + } + std::vector inputs{tuple_get_item_prim, node, get_item_value}; + CNodePtr get_item_cnode = func_graph->NewCNode(inputs); + std::vector shape_vector; + auto abstract_tensor = std::make_shared(kFloat32, shape_vector); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "create abstract_tensor 
failed"; + return nullptr; + } + get_item_cnode->set_abstract(abstract_tensor); + get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" + + std::to_string(item_index)); + return get_item_cnode; +} + +STATUS TfliteLstmCellFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode, + const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(while_cnode != nullptr); + auto manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(ERROR) << "manager is nullptr"; + return RET_ERROR; + } + auto tr = manager->Transact(); + auto while_node_users = manager->node_users()[while_cnode]; + for (auto &node_user : while_node_users) { + if (node_user.first == output_get_item) { + continue; + } + if (!utils::isa(node_user.first)) { + return RET_ERROR; + } + auto get_item = utils::cast(node_user.first); + if (GetCNodeType(get_item) != schema::PrimitiveType_TupleGetItem) { + return RET_ERROR; + } + auto new_inputs = get_item->inputs(); + if (new_inputs.size() != 3) { + return RET_ERROR; + } + new_inputs[1] = lstm_cnode; + auto index_vnode = get_item->input(2); + if (!utils::isa(index_vnode)) { + MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node"; + return RET_ERROR; + } + auto value_node = utils::cast(index_vnode); + if (value_node == nullptr) { + MS_LOG(ERROR) << "cast to ValueNode failed"; + return RET_ERROR; + } + auto origin_index = GetValue(value_node->value()); + int new_index = origin_index == 4 ? 
2 : 1; + auto new_index_vnode = NewValueNode(MakeValue(new_index)); + new_inputs[2] = new_index_vnode; + get_item->set_inputs(new_inputs); + get_item->set_fullname_with_scope(lstm_cnode->fullname_with_scope() + "_getitem_" + std::to_string(new_index)); + if (get_item->abstract() == nullptr) { + MS_LOG(ERROR) << "get_item's abstract is nullptr"; + return RET_ERROR; + } + + std::vector squeeze_axis{0}; + auto squeeze_node = CreateSqueezeNode(func_graph, get_item, squeeze_axis); + if (squeeze_node == nullptr) { + return RET_ERROR; + } + + auto get_item_users = manager->node_users()[get_item]; + for (auto &get_item_user : get_item_users) { + tr.SetEdge(get_item_user.first, get_item_user.second, squeeze_node); + } + } + tr.Commit(); + return RET_OK; +} + +STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) { + MS_ASSERT(cnode != nullptr); + AbstractBasePtrList abstract_list; + for (int i = 0; i < output_num; ++i) { + std::vector shape_vector; + auto abstract_tensor = std::make_shared(kFloat32, shape_vector); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "create abstract_tensor failed"; + return RET_ERROR; + } + abstract_list.emplace_back(abstract_tensor); + } + auto abstract_tuple = std::make_shared(abstract_list); + if (abstract_tuple == nullptr) { + MS_LOG(ERROR) << "create abstract_tuple failed"; + return RET_ERROR; + } + cnode->set_abstract(abstract_tuple); + return RET_OK; +} + +CNodePtr TfliteLstmCellFusion::CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node, + const std::vector &axis) const { + MS_ASSERT(func_graph != nullptr); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new SqueezeT failed"; + return nullptr; + } + attr->axis = axis; + auto new_primitive_t = std::make_unique(); + if (new_primitive_t == nullptr) { + MS_LOG(ERROR) << "primitive_t is nullptr"; + return nullptr; + } + new_primitive_t->value.type = schema::PrimitiveType_Squeeze; + 
new_primitive_t->value.value = attr.release(); + auto new_primtive_c = std::shared_ptr(lite::PrimitiveC::Create(new_primitive_t.release())); + if (new_primtive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return nullptr; + } + ValueNodePtr value_node = NewValueNode(new_primtive_c); + auto squeeze_cnode = func_graph->NewCNode({value_node, input_node}); + squeeze_cnode->set_abstract(input_node->abstract()->Clone()); + squeeze_cnode->set_fullname_with_scope("squeeze_" + input_node->fullname_with_scope()); + return squeeze_cnode; +} + +const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(node != nullptr); + MS_LOG(DEBUG) << "lstm fusion pass"; + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return nullptr; + } + + if (!utils::isa(node)) { + return nullptr; + } + auto tensor_list_stack_cnode = utils::cast(node); + auto tuple_get_item_node = tensor_list_stack_cnode->input(1); + if (!utils::isa(tuple_get_item_node)) { + return nullptr; + } + auto tuple_get_item_cnode = utils::cast(tuple_get_item_node); + auto while_node = tuple_get_item_cnode->input(1); + if (!utils::isa(while_node)) { + return nullptr; + } + auto while_cnode = utils::cast(while_node); + + if (CheckIfCNodeIsNull(while_cnode) != RET_OK || CheckInputSize(while_cnode, while_inputs_num_) != RET_OK) { + return nullptr; + } + if (!CheckReferencedOutputs(func_graph, while_cnode)) { + return nullptr; + } + PrimitiveVarMapPtr primitive_vars_cond = std::make_shared(); + auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond); + auto cond_equiv = CheckSubGraph(func_graph, cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), + cond_cnodes_num_, cond_nodes_num_); + if (cond_equiv == nullptr || 
cond_equiv->empty()) { + return nullptr; + } + PrimitiveVarMapPtr primitive_vars_body = std::make_shared(); + auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body); + auto body_equiv = CheckSubGraph(func_graph, body_graph_pattern, primitive_vars_body, while_cnode->input(2), + body_cnodes_num_, body_nodes_num_); + if (body_equiv == nullptr || body_equiv->empty()) { + return nullptr; + } + float smooth = 0.0f; + if (!CheckBodyGraph(func_graph, body_equiv, while_cnode, &smooth)) { + return nullptr; + } + const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope(); + auto lstm_node = CreateLSTMNode(func_graph, equiv, body_equiv, lstm_name, smooth); + if (lstm_node == nullptr) { + return nullptr; + } + auto status = SetAbstractTuple(lstm_node, kLSTMOutputNum); + if (status != RET_OK) { + return nullptr; + } + + auto get_item_node = CreateOutputGetItem(func_graph, lstm_node, 0); + if (get_item_node == nullptr) { + MS_LOG(DEBUG) << "create lstm output get_item node failed"; + return nullptr; + } + + status = AdjustOtherGetItems(func_graph, while_cnode, lstm_node, tuple_get_item_cnode); + if (status != RET_OK) { + return nullptr; + } + + std::vector squeeze_axis{1}; // our lstm output:0 have an extra axis that tflite not have, it must be squeezed + auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis); + if (squeeze_node == nullptr) { + return nullptr; + } + + auto cond_cnode_index_pair = std::make_shared(while_cnode, 1); + func_graph->DropFuncGraphCNodeIndex(cond_cnode_index_pair); + auto body_cnode_index_pair = std::make_shared(while_cnode, 2); + func_graph->DropFuncGraphCNodeIndex(body_cnode_index_pair); + MS_LOG(INFO) << "lstm node:" << lstm_node->fullname_with_scope() << " fusion success"; + return squeeze_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h new file mode 
100644 index 0000000000..7a9ab7ed3e --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace opt { +class TfliteLstmCellFusion : public PatternProcessPass { + public: + explicit TfliteLstmCellFusion(const std::string &name = "tflite_lstm_cell_fusion", bool multigraph = true, + int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0, + int body_nodes_num = 0, int body_cnodes_num = 0); + ~TfliteLstmCellFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + public: + static EquivPtr MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars, + const AnfNodePtr &pattern); + static EquivPtr CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern, + const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph, + const size_t cnode_num, const size_t all_node_num); + static lite::STATUS 
SetAbstractTuple(const CNodePtr &cnode, const int output_num); + static CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index); + + protected: + VarPtr cell_smooth_old_ = nullptr; + VarPtr cell_smooth_new_ = nullptr; + VarPtr hidden_smooth_old_ = nullptr; + VarPtr hidden_smooth_new_ = nullptr; + std::vector while_input_vars_; + + lite::STATUS GetFloatScalarFromParamValueLite(const AnfNodePtr ¶m_value, float *v) const; + CNodePtr CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node, + const std::vector &axis) const; + lite::STATUS AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode, + const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) const; + AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; + virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; + virtual CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv, + const std::string &base_name, const float smooth) const; + + private: + bool CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const CNodePtr &while_cnode, + float *smooth) const; + bool CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) const; + + lite::STATUS GetConcatedParam(const std::vector ¶ms, const ParameterPtr &new_param, + bool is_bias) const; + + private: + size_t while_input_var_num_ = 0; + size_t while_inputs_num_ = 0; + size_t cond_nodes_num_ = 0; + size_t cond_cnodes_num_ = 0; + size_t body_nodes_num_ = 0; + size_t body_cnodes_num_ = 0; +}; + +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_