From: @wangzhe128 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -0,0 +1,134 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/fp32/gru_fp32.h" | |||
| #include <string.h> | |||
| #include "nnacl/fp32/lstm_fp32.h" | |||
| #include "nnacl/fp32/activation_fp32.h" | |||
| #include "nnacl/fp32/arithmetic_fp32.h" | |||
| void InitGruGate(float *gate_buffer, const float *bias, const GruParameter *gru_parm) { | |||
| int gate_offest = 0; | |||
| for (int l = 0; l < 3; l++) { | |||
| int batch_offest = gate_offest; | |||
| int bias_offest = l * gru_parm->hidden_size_; | |||
| for (int b = 0; b < gru_parm->batch_; b++) { | |||
| memcpy(gate_buffer + batch_offest, bias + bias_offest, gru_parm->hidden_size_ * sizeof(float)); | |||
| batch_offest += gru_parm->hidden_size_; | |||
| } | |||
| gate_offest += gru_parm->batch_ * gru_parm->hidden_size_; | |||
| } | |||
| } | |||
| void GruStepUnit(float *output, const float *input, const float *input_reset_weight, const float *input_update_weight, | |||
| const float *input_hidden_weight, const float *state_reset_weight, const float *state_update_weight, | |||
| const float *state_hidden_weight, const float *bias, float *hidden_state, float *gate_buffer, | |||
| const GruParameter *gru_parm) { | |||
| InitGruGate(gate_buffer, bias, gru_parm); | |||
| float *update_gate = gate_buffer; | |||
| float *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_; | |||
| float *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2; | |||
| // input * weight | |||
| MatMulAcc(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_); | |||
| MatMulAcc(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_); | |||
| MatMulAcc(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_); | |||
| // state * weight | |||
| MatMulAcc(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, | |||
| gru_parm->hidden_size_); | |||
| MatMulAcc(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_, | |||
| gru_parm->hidden_size_); | |||
| // update reset_gate | |||
| Sigmoid(reset_gate, gru_parm->batch_ * gru_parm->hidden_size_, reset_gate); | |||
| // update update_gate | |||
| Sigmoid(update_gate, gru_parm->batch_ * gru_parm->hidden_size_, update_gate); | |||
| ElementMul(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_); | |||
| MatMulAcc(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, | |||
| gru_parm->hidden_size_); | |||
| Tanh(hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_, hidden_buffer); | |||
| ElementMul(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_); | |||
| ArithmeticParameter parameter; | |||
| parameter.in_elements_num0_ = 1; | |||
| parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_; | |||
| const float one = 1.0f; | |||
| ElementOptSub(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, ¶meter); | |||
| ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_); | |||
| memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float)); | |||
| } | |||
| void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias, | |||
| float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm) { | |||
| // forward | |||
| const float *input_update_weight = weight_g; | |||
| const float *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_; | |||
| const float *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2; | |||
| const float *state_update_weight = weight_r; | |||
| const float *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_; | |||
| const float *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2; | |||
| for (int t = 0; t < check_seq_len; t++) { | |||
| const float *input_ptr = input + t * gru_parm->input_step_; | |||
| float *output_ptr = output + t * gru_parm->output_step_; | |||
| GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, state_reset_weight, | |||
| state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer, gru_parm); | |||
| } | |||
| // zero out extra fw outputs | |||
| for (int t = check_seq_len; t < gru_parm->seq_len_; t++) { | |||
| float *output_ptr = output + t * gru_parm->output_step_; | |||
| for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) { | |||
| output_ptr[i] = 0.0f; | |||
| } | |||
| } | |||
| // backward | |||
| if (gru_parm->bidirectional_) { | |||
| input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3; | |||
| input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4; | |||
| input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5; | |||
| state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3; | |||
| state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4; | |||
| state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5; | |||
| float *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_; | |||
| const float *backward_bias = bias + 3 * gru_parm->hidden_size_; | |||
| float *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_; | |||
| for (int t = check_seq_len - 1; t >= 0; t--) { | |||
| const float *input_ptr = input + t * gru_parm->input_step_; | |||
| float *output_ptr = backward_output + t * gru_parm->output_step_; | |||
| GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, | |||
| state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, backward_hidden_state, | |||
| gate_buffer, gru_parm); | |||
| } | |||
| // zero out extra bw outputs | |||
| for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) { | |||
| float *output_ptr = backward_output + t * gru_parm->output_step_; | |||
| for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) { | |||
| output_ptr[i] = 0.0f; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_ | |||
| #include "nnacl/op_base.h" | |||
| typedef struct GruParameter { | |||
| // Primitive parameter | |||
| OpParameter op_parameter_; | |||
| // shape correlative | |||
| int input_size_; | |||
| int hidden_size_; // output_size | |||
| int seq_len_; | |||
| int batch_; | |||
| // other parameter | |||
| int input_step_; | |||
| int output_step_; | |||
| bool bidirectional_; | |||
| } GruParameter; | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias, | |||
| float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_ | |||
| @@ -16,6 +16,7 @@ | |||
| #include "nnacl/fp32/lstm_fp32.h" | |||
| #include <string.h> | |||
| #include <float.h> | |||
| #include "nnacl/fp32/activation_fp32.h" | |||
| #include "nnacl/fp32/arithmetic_fp32.h" | |||
| @@ -79,21 +80,63 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int | |||
| } | |||
| } | |||
| int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(input0 + index); | |||
| float32x4_t vout = vld1q_f32(output + index); | |||
| vout = vmlaq_n_f32(vout, vin0, input1); | |||
| vst1q_f32(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] += input0[index] * input1; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| void UpdataState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate, | |||
| int batch, int hidden_size) { | |||
| float *state_buffer, int batch, int hidden_size, const float smooth) { | |||
| if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // smooth * old_cell_state | |||
| memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float)); | |||
| ArithmeticParameter parameter; | |||
| parameter.in_elements_num0_ = batch * hidden_size; | |||
| parameter.in_elements_num1_ = 1; | |||
| ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, ¶meter); | |||
| } | |||
| ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size); | |||
| ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size); | |||
| if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // (1 - smooth) * new_cell_state | |||
| ElementOptMulAcc(cell_state, 1 - smooth, state_buffer, batch * hidden_size); | |||
| } | |||
| } | |||
| void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, int batch, int hidden_size) { | |||
| void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, float *state_buffer_in, | |||
| int batch, int hidden_size, const float smooth) { | |||
| float *state_buffer = state_buffer_in + batch * hidden_size; | |||
| if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { | |||
| memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float)); | |||
| ArithmeticParameter parameter; | |||
| parameter.in_elements_num0_ = batch * hidden_size; | |||
| parameter.in_elements_num1_ = 1; | |||
| ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, ¶meter); | |||
| } | |||
| Tanh(cell_state, batch * hidden_size, hidden_state); | |||
| ElementMul(hidden_state, output_gate, hidden_state, batch * hidden_size); | |||
| if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { | |||
| ElementOptMulAcc(hidden_state, 1 - smooth, state_buffer, batch * hidden_size); | |||
| } | |||
| } | |||
| void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight, | |||
| const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight, | |||
| const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight, | |||
| const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, | |||
| const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, | |||
| const LstmParameter *lstm_parm) { | |||
| InitGate(gate_buffer, bias, lstm_parm); | |||
| @@ -129,17 +172,26 @@ void LstmStepUnit(float *output, const float *input, const float *input_input_we | |||
| // update cell_gate | |||
| Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate); | |||
| // update cell state | |||
| UpdataState(cell_state, forget_gate, input_gate, cell_gate, lstm_parm->batch_, lstm_parm->hidden_size_); | |||
| UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, | |||
| lstm_parm->smooth_); | |||
| // update output_gate | |||
| Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate); | |||
| // update output | |||
| UpdataOutput(cell_state, output_gate, hidden_state, lstm_parm->batch_, lstm_parm->hidden_size_); | |||
| UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, | |||
| lstm_parm->smooth_); | |||
| memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); | |||
| if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) { | |||
| memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); | |||
| memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_, | |||
| lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); | |||
| } | |||
| } | |||
| void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, | |||
| float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm) { | |||
| float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, | |||
| const LstmParameter *lstm_parm) { | |||
| // forward | |||
| const float *input_input_weight = weight_i; | |||
| const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; | |||
| @@ -156,7 +208,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float | |||
| float *output_ptr = output + t * lstm_parm->output_step_; | |||
| LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight, | |||
| state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state, | |||
| cell_state, gate_buffer, lstm_parm); | |||
| cell_state, gate_buffer, state_buffer, lstm_parm); | |||
| } | |||
| // backward | |||
| @@ -180,7 +232,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float | |||
| float *output_ptr = backward_output + t * lstm_parm->output_step_; | |||
| LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, | |||
| input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, | |||
| backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, lstm_parm); | |||
| backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm); | |||
| } | |||
| } | |||
| } | |||
| @@ -31,13 +31,24 @@ typedef struct LstmParameter { | |||
| int input_step_; | |||
| int output_step_; | |||
| bool bidirectional_; | |||
| // smooth factor for hidden/cell state calculation: | |||
| // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth) | |||
| // output_cell = old_cell * smooth + new_cell * (1 - smooth) | |||
| float smooth_; | |||
| } LstmParameter; | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size); | |||
| void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size); | |||
| int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); | |||
| void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, | |||
| float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm); | |||
| float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, | |||
| const LstmParameter *lstm_parm); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -262,6 +262,7 @@ union PrimitiveType { | |||
| Merge, | |||
| Mod, | |||
| GeLU, | |||
| Gru, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -1005,6 +1005,11 @@ table OneHot { | |||
| table Lstm{ | |||
| bidirection: bool = false; | |||
| smooth: float = 0.0; | |||
| } | |||
| table Gru{ | |||
| bidirection: bool = false; | |||
| } | |||
| table PriorBox { | |||
| @@ -0,0 +1,121 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/gru.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| bool Gru::GetBidirection() const { return this->primitive_->value.AsGru()->bidirection; } | |||
| void Gru::SetBidirection(bool bidirection) { this->primitive_->value.AsGru()->bidirection = bidirection; } | |||
| #else | |||
| bool Gru::GetBidirection() const { return this->primitive_->value_as_Gru()->bidirection(); } | |||
| int Gru::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Gru(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Gru return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateGru(*fbb, attr->bidirection()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gru, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *GruCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Gru>(primitive); } | |||
| Registry GruRegistry(schema::PrimitiveType_Gru, GruCreator); | |||
| #endif | |||
| const int kGruInputNum = 5; | |||
| const int kGruInputWithSeqLenNum = 6; | |||
| const int kGruOutputNum = 2; | |||
| int Gru::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| MS_ASSERT(this->primitive_ != nullptr); | |||
| if ((inputs_.size() != kGruInputNum && inputs_.size() != kGruInputWithSeqLenNum) || | |||
| outputs_.size() != kGruOutputNum) { | |||
| MS_LOG(ERROR) << "OpGru inputs or outputs size error."; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto weight_gate = inputs_.at(1); | |||
| MS_ASSERT(weight_gate != nullptr); | |||
| auto weight_recurrence = inputs_.at(2); | |||
| MS_ASSERT(weight_recurrence != nullptr); | |||
| auto bias = inputs_.at(3); | |||
| MS_ASSERT(bias != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| for (int i = 0; i < kGruOutputNum; i++) { | |||
| outputs_.at(i)->set_data_type(input->data_type()); | |||
| outputs_.at(i)->set_format(input->format()); | |||
| } | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| auto in_shape = input->shape(); // seq_len, batch, input_size | |||
| auto w_gate_shape = weight_gate->shape(); // num_direction, hidden_size * 3, input_size | |||
| auto w_recu_shape = weight_recurrence->shape(); // num_direction, hidden_size * 3, hidden_size | |||
| auto bias_shape = bias->shape(); // num_direction, hidden_size * 6 | |||
| if (in_shape.size() != 3 || w_gate_shape.size() != 3 || w_recu_shape.size() != 3) { | |||
| MS_LOG(ERROR) << "OpGru input dims should be 3."; | |||
| return RET_ERROR; | |||
| } | |||
| if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) { | |||
| MS_LOG(ERROR) << "OpGru w_gate, w_recu and bias hidden size not match."; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputs_.size() == kGruInputWithSeqLenNum) { | |||
| auto seq_len_shape = inputs_.at(5)->shape(); | |||
| if (seq_len_shape[0] > 1) { | |||
| MS_LOG(WARNING) << "OpGru with batch_size > 1 only support all same sequence_len now."; | |||
| return RET_ERROR; | |||
| } | |||
| if (seq_len_shape.size() != 1 && seq_len_shape[0] != in_shape[1]) { | |||
| MS_LOG(ERROR) << "OpGru sequence_len shape[0] and batch_size not match."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| int hidden_size = w_gate_shape[1] / 3; | |||
| // set output | |||
| std::vector<int> out_shape(in_shape); | |||
| out_shape[2] = hidden_size; | |||
| if (GetBidirection()) { | |||
| out_shape.insert(out_shape.begin() + 1, 2); | |||
| } else { | |||
| out_shape.insert(out_shape.begin() + 1, 1); | |||
| } | |||
| output->set_shape(out_shape); | |||
| // set hidden state | |||
| std::vector<int> state_shape(in_shape); | |||
| state_shape[0] = GetBidirection() ? 2 : 1; | |||
| state_shape[2] = hidden_size; | |||
| outputs_[1]->set_shape(state_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_GRU_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_GRU_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| /* | |||
| * gru with linear_before_reset = 0 | |||
| */ | |||
| class Gru : public PrimitiveC { | |||
| public: | |||
| Gru() = default; | |||
| ~Gru() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Gru, PrimitiveC); | |||
| explicit Gru(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetBidirection(bool bidirection); | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| bool GetBidirection() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_OPS_GRU_H_ | |||
| @@ -25,11 +25,16 @@ namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| bool Lstm::GetBidirection() const { return this->primitive_->value.AsLstm()->bidirection; } | |||
| float Lstm::GetSmooth() const { return this->primitive_->value.AsLstm()->smooth; } | |||
| void Lstm::SetBidirection(bool bidirection) { this->primitive_->value.AsLstm()->bidirection = bidirection; } | |||
| void Lstm::SetSmooth(float smooth) { this->primitive_->value.AsLstm()->smooth = smooth; } | |||
| #else | |||
| bool Lstm::GetBidirection() const { return this->primitive_->value_as_Lstm()->bidirection(); } | |||
| float Lstm::GetSmooth() const { return this->primitive_->value_as_Lstm()->smooth(); } | |||
| int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| @@ -38,7 +43,7 @@ int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||
| MS_LOG(ERROR) << "value_as_Lstm return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateLstm(*fbb, attr->bidirection()); | |||
| auto val_offset = schema::CreateLstm(*fbb, attr->bidirection(), attr->smooth()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Lstm, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| @@ -33,12 +33,14 @@ class Lstm : public PrimitiveC { | |||
| MS_DECLARE_PARENT(Lstm, PrimitiveC); | |||
| explicit Lstm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetBidirection(bool bidirection); | |||
| void SetSmooth(float smooth); | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| bool GetBidirection() const; | |||
| float GetSmooth() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/gru.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "nnacl/fp32/gru_fp32.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateGruParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| GruParameter *gru_param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter))); | |||
| if (gru_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc GruParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(gru_param, 0, sizeof(GruParameter)); | |||
| gru_param->op_parameter_.type_ = primitive->Type(); | |||
| auto param = reinterpret_cast<mindspore::lite::Gru *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| if (param == nullptr) { | |||
| free(gru_param); | |||
| MS_LOG(ERROR) << "get Gru param nullptr."; | |||
| return nullptr; | |||
| } | |||
| gru_param->bidirectional_ = param->GetBidirection(); | |||
| return reinterpret_cast<OpParameter *>(gru_param); | |||
| } | |||
| Registry GruParameterRegistry(schema::PrimitiveType_Gru, PopulateGruParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -36,6 +36,7 @@ OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive) | |||
| return nullptr; | |||
| } | |||
| lstm_param->bidirectional_ = param->GetBidirection(); | |||
| lstm_param->smooth_ = param->GetSmooth(); | |||
| return reinterpret_cast<OpParameter *>(lstm_param); | |||
| } | |||
| Registry LstmParameterRegistry(schema::PrimitiveType_Lstm, PopulateLstmParameter); | |||
| @@ -161,6 +161,7 @@ | |||
| #include "src/ops/switch.h" | |||
| #include "src/ops/partial.h" | |||
| #include "src/ops/gelu.h" | |||
| #include "src/ops/gru.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| @@ -995,6 +996,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) AssertOP(primitive); | |||
| case schema::PrimitiveType_GeLU: | |||
| return new (std::nothrow) GeLU(primitive); | |||
| case schema::PrimitiveType_Gru: | |||
| return new (std::nothrow) Gru(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| return new (std::nothrow) ActivationGrad(primitive); | |||
| @@ -0,0 +1,165 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/gru_fp32.h" | |||
| #include <vector> | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Gru; | |||
| namespace mindspore::kernel { | |||
| void GruCPUKernel::FreeTmpBuffer() { | |||
| if (gate_buffer_ != nullptr) { | |||
| free(gate_buffer_); | |||
| gate_buffer_ = nullptr; | |||
| } | |||
| if (bias_ptr_ != nullptr) { | |||
| free(bias_ptr_); | |||
| bias_ptr_ = nullptr; | |||
| } | |||
| weight_g_ptr_ = nullptr; | |||
| weight_r_ptr_ = nullptr; | |||
| } | |||
| int GruCPUKernel::InitParam() { | |||
| auto input = in_tensors_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| std::vector<int> in_shape = input->shape(); | |||
| gru_parm_->seq_len_ = in_shape.at(0); | |||
| gru_parm_->batch_ = in_shape.at(1); | |||
| gru_parm_->input_size_ = in_shape.at(2); | |||
| auto weight_g = in_tensors_.at(1); | |||
| MS_ASSERT(weight_g != nullptr); | |||
| std::vector<int> w_shape = weight_g->shape(); | |||
| gru_parm_->hidden_size_ = w_shape.at(1) / 3; | |||
| gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_; | |||
| gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_ | |||
| : gru_parm_->batch_ * gru_parm_->hidden_size_; | |||
| return RET_OK; | |||
| } | |||
| int GruCPUKernel::InitBuffer() { | |||
| gate_buffer_ = reinterpret_cast<float *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float))); | |||
| if (gate_buffer_ == nullptr) { | |||
| MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int GruCPUKernel::InitWeightBias() { | |||
| auto weight_gate = in_tensors_.at(1); | |||
| MS_ASSERT(weight_gate != nullptr); | |||
| weight_g_ptr_ = reinterpret_cast<float *>(weight_gate->data_c()); | |||
| auto weight_recu = in_tensors_.at(2); | |||
| MS_ASSERT(weight_recu != nullptr); | |||
| weight_r_ptr_ = reinterpret_cast<float *>(weight_recu->data_c()); | |||
| int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_; | |||
| bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float))); | |||
| if (bias_ptr_ == nullptr) { | |||
| MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error."; | |||
| return RET_ERROR; | |||
| } | |||
| auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c()); | |||
| const int state_bias_offset = 3 * gru_parm_->hidden_size_; | |||
| for (int i = 0; i < state_bias_offset; i++) { | |||
| bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset]; | |||
| } | |||
| if (gru_parm_->bidirectional_) { | |||
| bias_data += 3 * gru_parm_->hidden_size_ * 2; | |||
| auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_; | |||
| for (int i = 0; i < state_bias_offset; i++) { | |||
| backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset]; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int GruCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int GruCPUKernel::ReSize() { | |||
| FreeTmpBuffer(); | |||
| auto ret = InitParam(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "GruCPUKernel InitParam error."; | |||
| return RET_ERROR; | |||
| } | |||
| ret = InitWeightBias(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error."; | |||
| FreeTmpBuffer(); | |||
| return RET_ERROR; | |||
| } | |||
| ret = InitBuffer(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "GruCPUKernel InitBuffer error."; | |||
| FreeTmpBuffer(); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int GruCPUKernel::Run() { | |||
| auto input = in_tensors_.at(kInputIndex); | |||
| MS_ASSERT(input != nullptr); | |||
| auto hidden_state = in_tensors_.at(4); | |||
| MS_ASSERT(hidden_state != nullptr); | |||
| auto output = out_tensors_.at(0); | |||
| MS_ASSERT(output != nullptr); | |||
| auto input_ptr = reinterpret_cast<float *>(input->data_c()); | |||
| MS_ASSERT(input_ptr); | |||
| auto output_ptr = reinterpret_cast<float *>(output->MutableData()); | |||
| MS_ASSERT(output_ptr); | |||
| auto output_hidden_state = out_tensors_[1]; | |||
| memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float)); | |||
| int check_seq_len = gru_parm_->seq_len_; | |||
| if (in_tensors_.size() == 6) { | |||
| auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c()); | |||
| if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) { | |||
| MS_LOG(ERROR) << "different batch seq_len is currently not supported"; | |||
| return RET_ERROR; | |||
| } | |||
| check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0])); | |||
| } | |||
| MS_ASSERT(weight_g_ptr_); | |||
| MS_ASSERT(weight_r_ptr_); | |||
| MS_ASSERT(bias_ptr_); | |||
| MS_ASSERT(gate_buffer_); | |||
| Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_, | |||
| reinterpret_cast<float *>(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_); | |||
| return RET_OK; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gru, LiteKernelCreator<GruCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp32/gru_fp32.h" | |||
| namespace mindspore::kernel { | |||
| class GruCPUKernel : public LiteKernel { | |||
| public: | |||
| GruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_); | |||
| } | |||
| ~GruCPUKernel() override { FreeTmpBuffer(); } | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| void FreeTmpBuffer(); | |||
| int InitParam(); | |||
| int InitBuffer(); | |||
| int InitWeightBias(); | |||
| float *gate_buffer_ = nullptr; | |||
| const float *weight_g_ptr_ = nullptr; | |||
| const float *weight_r_ptr_ = nullptr; | |||
| float *bias_ptr_ = nullptr; | |||
| GruParameter *gru_parm_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_ | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/lstm_fp32.h" | |||
| #include <float.h> | |||
| #include <vector> | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| @@ -32,6 +33,10 @@ void LstmCPUKernel::FreeTmpBuffer() { | |||
| free(gate_buffer_); | |||
| gate_buffer_ = nullptr; | |||
| } | |||
| if (state_buffer_ != nullptr) { | |||
| free(state_buffer_); | |||
| state_buffer_ = nullptr; | |||
| } | |||
| if (weight_i_ptr_ != nullptr) { | |||
| free(weight_i_ptr_); | |||
| weight_i_ptr_ = nullptr; | |||
| @@ -71,6 +76,14 @@ int LstmCPUKernel::InitBuffer() { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; | |||
| return RET_ERROR; | |||
| } | |||
| if (!(lstm_parm_->smooth_ >= -FLT_EPSILON && lstm_parm_->smooth_ <= FLT_EPSILON)) { | |||
| int buffer_size = 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float); | |||
| state_buffer_ = reinterpret_cast<float *>(malloc(buffer_size)); | |||
| if (state_buffer_ == nullptr) { | |||
| MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -173,7 +186,7 @@ int LstmCPUKernel::Run() { | |||
| MS_ASSERT(gate_buffer_); | |||
| Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, | |||
| reinterpret_cast<float *>(output_hidden_state->MutableData()), | |||
| reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, lstm_parm_); | |||
| reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, state_buffer_, lstm_parm_); | |||
| return RET_OK; | |||
| } | |||
| @@ -44,6 +44,7 @@ class LstmCPUKernel : public LiteKernel { | |||
| int InitWeightBias(); | |||
| float *gate_buffer_ = nullptr; | |||
| float *state_buffer_ = nullptr; | |||
| float *weight_i_ptr_ = nullptr; | |||
| float *weight_h_ptr_ = nullptr; | |||
| float *bias_ptr_ = nullptr; | |||
| @@ -187,6 +187,9 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | |||
| @@ -0,0 +1 @@ | |||
| decoder_step_201217.pb 5 | |||
| @@ -925,6 +925,41 @@ function Run_arm64() { | |||
| fi | |||
| done < ${models_compatibility_config} | |||
| # Run tf converted models: | |||
| while read line; do | |||
| model_name=${line} | |||
| if [[ $model_name == \#* ]]; then | |||
| continue | |||
| fi | |||
| model_name=`echo ${tf_line_info}|awk -F ' ' '{print $1}'` | |||
| input_num=`echo ${tf_line_info}|awk -F ' ' '{print $2}'` | |||
| input_files='' | |||
| for i in $(seq 1 $input_num) | |||
| do | |||
| input_files=$input_files'/data/local/tmp/input_output/input/'$model_name'.ms_'$i'.bin,' | |||
| done | |||
| echo ${model_name} >> "${run_arm64_log_file}" | |||
| echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> "${run_arm64_log_file}" | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> adb_run_cmd.txt | |||
| adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}" | |||
| if [ $? = 0 ]; then | |||
| run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} | |||
| else | |||
| run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| fi | |||
| # run benchmark test without clib data | |||
| echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> "${run_arm64_log_file}" | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt | |||
| adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}" | |||
| if [ $? = 0 ]; then | |||
| run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} | |||
| else | |||
| run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| fi | |||
| done < ${models_tf_config} | |||
| # Run tflite converted models: | |||
| while read line; do | |||
| model_name=${line} | |||
| @@ -46,6 +46,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/fusion/batchmatmul_fusion.cc | |||
| ../optimizer/fusion/sigmoid_mul_fusion.cc | |||
| ../optimizer/fusion/conv_conv_fusion.cc | |||
| ../optimizer/fusion/tflite_lstm_cell_fusion.cc | |||
| ../optimizer/fusion/tf_lstm_cell_fusion.cc | |||
| ../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc | |||
| ../optimizer/graph/weight_format_transform_pass.cc | |||
| ../optimizer/graph/weight_format_hardcode_pass.cc | |||
| ../optimizer/graph/clip_convert_activation_pass.cc | |||
| @@ -29,6 +29,9 @@ | |||
| #include "tools/optimizer/fusion/batchmatmul_fusion.h" | |||
| #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_conv_fusion.h" | |||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||
| #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h" | |||
| #include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" | |||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | |||
| #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" | |||
| #include "tools/optimizer/graph/identity_remove_pass.h" | |||
| @@ -114,6 +117,9 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>()); | |||
| } | |||
| auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>(); | |||
| weight_format_hardcode_pass->SetFmkType(config->fmk); | |||
| @@ -572,7 +572,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C | |||
| if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) { | |||
| type = TypeIdToType(kObjectTypeTensorType); | |||
| } | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector)); | |||
| auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape_vector); | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "create AbstractTensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| anf_node->set_abstract(abstract); | |||
| anf_node_map->insert(std::pair(op.name(), anf_node)); | |||
| } else { | |||
| AbstractBasePtrList abstractList; | |||
| @@ -589,6 +594,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C | |||
| std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue}; | |||
| CNodePtr getItemCNode = anf_graph->NewCNode(inputs); | |||
| std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | |||
| auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "create AbstractTensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| getItemCNode->set_abstract(abstract); | |||
| getItemCNode->set_fullname_with_scope(output_item_name); | |||
| anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); | |||
| } | |||
| @@ -63,7 +63,11 @@ STATUS TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| *output_size = 1; | |||
| return AddOpInput(tf_op, 0, inputs); | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| return AddOpInput(tf_op, 1, inputs); | |||
| } | |||
| TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser()); | |||
| } // namespace lite | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_select_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFSelectParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF SelectParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::SwitchT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Switch; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfSelectParser("Select", new TFSelectParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFSelectParser : public TFNodeParser { | |||
| public: | |||
| TFSelectParser() = default; | |||
| ~TFSelectParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_ | |||
| @@ -122,7 +122,7 @@ std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) { | |||
| std::sregex_token_iterator()); | |||
| std::string ret = input_name; | |||
| if (input_splits.size() == 3) { | |||
| if (input_splits[2] == "0") { | |||
| if (input_splits[2].compare("0") == 0) { | |||
| ret = input_splits[0]; | |||
| } else { | |||
| ret = input_splits[0] + ":" + input_splits[2]; // multi output node | |||
| @@ -0,0 +1,679 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" | |||
| #include <memory> | |||
| #include <functional> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/common/utils.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "securec/include/securec.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kWhileCommonInputsLength = 2; | |||
| constexpr size_t kWhileUniqInputsLength = 6; | |||
| constexpr size_t kCondNodesNum = 12; | |||
| constexpr size_t kCondCNodesNum = 4; | |||
| constexpr size_t kBodyNodesNum = 69; | |||
| constexpr size_t kBodyCNodesNum = 25; | |||
| const auto &p1 = std::placeholders::_1; | |||
| bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); } | |||
| bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| return opt::GetCNodeType(n) == type; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, bool multigraph) | |||
| : PatternProcessPass(name, multigraph) { | |||
| /* | |||
| * vars for while input | |||
| * common: | |||
| * 0:const0 1:init_state | |||
| * fw_while_inputs: | |||
| * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias | |||
| * bw_while_inputs: | |||
| * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias | |||
| */ | |||
| for (size_t i = 0; i < kWhileCommonInputsLength; ++i) { | |||
| common_vars_.emplace_back(std::make_shared<Var>()); | |||
| } | |||
| for (size_t i = 0; i < kWhileUniqInputsLength; ++i) { | |||
| fw_vars_.emplace_back(std::make_shared<Var>()); | |||
| bw_vars_.emplace_back(std::make_shared<Var>()); | |||
| } | |||
| input_ = std::make_shared<Var>(); | |||
| input_length_ = std::make_shared<Var>(); | |||
| transpose_input_ = std::make_shared<Var>(); | |||
| } | |||
| const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { | |||
| auto const1 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto ele_shape = std::make_shared<CondVar>(IsParameterNode); | |||
| // forward | |||
| auto fw_max1 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_}); | |||
| auto fw_max2 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, fw_max1}); | |||
| auto fw_shape = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_}); | |||
| auto fw_stride = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), fw_shape}); | |||
| auto fw_min = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2}); | |||
| auto fw_reserve = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape, | |||
| fw_stride}); | |||
| auto fw_from_tensor = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), | |||
| transpose_input_, ele_shape}); | |||
| auto is_fw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While)); | |||
| auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], common_vars_[0], fw_stride, common_vars_[0], | |||
| fw_reserve, common_vars_[1], fw_min, fw_from_tensor, input_length_}); | |||
| fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end()); | |||
| fw_while.emplace_back(common_vars_[1]); | |||
| auto fw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)), | |||
| fw_while, std::make_shared<Var>()}); | |||
| auto fw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)), | |||
| fw_get_item, ele_shape}); | |||
| auto fw_out_trans = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), fw_stack}); | |||
| // backward | |||
| auto bw_reverse_seq = VectorRef( | |||
| {std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), input_, input_length_}); | |||
| auto bw_max1 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_}); | |||
| auto bw_max2 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, bw_max1}); | |||
| auto bw_trans = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_reverse_seq}); | |||
| auto bw_shape = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans}); | |||
| auto bw_stride = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), bw_shape}); | |||
| auto bw_min = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2}); | |||
| auto bw_reserve = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape, | |||
| bw_stride}); | |||
| auto bw_from_tensor = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), bw_trans, | |||
| ele_shape}); | |||
| auto is_bw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While)); | |||
| auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], common_vars_[0], bw_stride, common_vars_[0], | |||
| bw_reserve, common_vars_[1], bw_min, bw_from_tensor, input_length_}); | |||
| bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end()); | |||
| bw_while.emplace_back(common_vars_[1]); | |||
| auto bw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)), | |||
| bw_while, std::make_shared<Var>()}); | |||
| auto bw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)), | |||
| bw_get_item, ele_shape}); | |||
| auto bw_out_trans = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_stack}); | |||
| auto bw_reverse1 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), bw_out_trans, | |||
| input_length_}); | |||
| auto concat = VectorRef( | |||
| {std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Concat)), fw_out_trans, bw_reverse1}); | |||
| return concat; | |||
| } | |||
| AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); | |||
| auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); | |||
| auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_LogicalAnd)); | |||
| auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); | |||
| VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2}); | |||
| VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4}); | |||
| VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref}); | |||
| VectorRef return_ref = VectorRef({is_return, logicaland_ref}); | |||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||
| auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true); | |||
| return pattern; | |||
| } | |||
| AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| std::vector<CondVarPtr> placeholders; | |||
| for (int i = 0; i < 13; ++i) { | |||
| placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode)); | |||
| } | |||
| VectorRef add = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef add1 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef get_item = VectorRef( | |||
| {std::make_shared<Var>("GetItem"), placeholders[6], placeholders[2], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[4]}); | |||
| VectorRef matmul1 = VectorRef({std::make_shared<Var>("Matmul"), concat_input_h, placeholders[8]}); | |||
| VectorRef biasadd1 = VectorRef({std::make_shared<Var>("BiasAdd"), matmul1, placeholders[9]}); | |||
| VectorRef sigmoid1 = VectorRef({std::make_shared<Var>("Sigmoid"), biasadd1}); | |||
| VectorRef split = VectorRef({std::make_shared<Var>("Split"), sigmoid1}); | |||
| VectorRef get_item1 = VectorRef({std::make_shared<Var>("TupleGetItem"), split, std::make_shared<Var>()}); | |||
| VectorRef get_item2 = VectorRef({std::make_shared<Var>("TupleGetItem"), split, std::make_shared<Var>()}); | |||
| VectorRef pre_reset = VectorRef({std::make_shared<Var>("Mul"), get_item1, placeholders[4]}); | |||
| VectorRef concat2 = VectorRef({std::make_shared<Var>("Concat"), get_item, pre_reset}); | |||
| VectorRef matmul2 = VectorRef({std::make_shared<Var>("Matmul"), concat2, placeholders[10]}); | |||
| VectorRef biasadd2 = VectorRef({std::make_shared<Var>("BiasAdd"), matmul2, placeholders[11]}); | |||
| VectorRef tanh = VectorRef({std::make_shared<Var>("Tanh"), biasadd2}); | |||
| VectorRef update_hidden = VectorRef({std::make_shared<Var>("Mul"), get_item2, placeholders[4]}); | |||
| VectorRef minus_update = | |||
| VectorRef({std::make_shared<Var>("Sub"), std::make_shared<CondVar>(IsParameterNode), get_item2}); | |||
| VectorRef updated = VectorRef({std::make_shared<Var>("Mul"), minus_update, tanh}); | |||
| VectorRef new_hidden = VectorRef({std::make_shared<Var>("Add"), update_hidden, updated}); | |||
| VectorRef greater_equal = VectorRef({std::make_shared<Var>("GreaterEqual"), placeholders[2], placeholders[7]}); | |||
| VectorRef select_output = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[12], new_hidden}); | |||
| VectorRef output = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], select_output}); | |||
| VectorRef select_hidden = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[4], new_hidden}); | |||
| auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple)); | |||
| std::vector<BaseRef> outputs = {is_make_tuple, add1, placeholders[1], add, | |||
| output, select_hidden, placeholders[5], placeholders[6], | |||
| placeholders[7]}; | |||
| outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end()); | |||
| VectorRef make_tuple_node = VectorRef(outputs); | |||
| auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); | |||
| VectorRef return_node = VectorRef({is_return, make_tuple_node}); | |||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||
| auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true); | |||
| return pattern; | |||
| } | |||
| ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const { | |||
| MS_ASSERT(parameter_anf != nullptr); | |||
| if (!utils::isa<ParameterPtr>(parameter_anf)) { | |||
| MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr"; | |||
| return nullptr; | |||
| } | |||
| auto parameter = utils::cast<ParameterPtr>(parameter_anf); | |||
| if (!parameter->has_default()) { | |||
| MS_LOG(DEBUG) << "parameter not have default value"; | |||
| return nullptr; | |||
| } | |||
| auto param_value = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param()); | |||
| return param_value; | |||
| } | |||
| STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, | |||
| const AnfNodePtr &bw_cand_kernel_anf, int *input_size, | |||
| int *hidden_size) const { | |||
| MS_ASSERT(fw_cand_kernel != nullptr); | |||
| MS_ASSERT(bw_cand_kernel != nullptr); | |||
| MS_ASSERT(input_size != nullptr); | |||
| MS_ASSERT(hidden_size != nullptr); | |||
| auto fw_cand_kernel_value = GetDefaultParamValue(fw_cand_kernel_anf); | |||
| if (fw_cand_kernel_value == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto fw_cand_kernel_shape = fw_cand_kernel_value->tensor_shape(); | |||
| if (fw_cand_kernel_shape.size() != 2) { | |||
| return RET_ERROR; | |||
| } | |||
| auto bw_cand_kernel_value = GetDefaultParamValue(bw_cand_kernel_anf); | |||
| if (bw_cand_kernel_value == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto bw_cand_kernel_shape = bw_cand_kernel_value->tensor_shape(); | |||
| if (bw_cand_kernel_shape.size() != 2) { | |||
| return RET_ERROR; | |||
| } | |||
| if (fw_cand_kernel_shape != bw_cand_kernel_shape) { | |||
| return RET_ERROR; | |||
| } | |||
| if (fw_cand_kernel_shape[1] <= 0 || fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1] <= 0) { | |||
| MS_LOG(DEBUG) << "gru input size or hidden size illegal"; | |||
| return RET_ERROR; | |||
| } | |||
| *hidden_size = fw_cand_kernel_shape[1]; | |||
| *input_size = fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1]; | |||
| return RET_OK; | |||
| } | |||
| ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name, | |||
| const std::vector<int> &shape, const TypeId type, | |||
| void **tensor_data) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(tensor_data != nullptr); | |||
| auto parameter = func_graph->add_parameter(); | |||
| parameter->set_name(name); | |||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector); | |||
| if (abstract_tensor == nullptr) { | |||
| return nullptr; | |||
| } | |||
| parameter->set_abstract(abstract_tensor); | |||
| auto gate_weight_default = std::make_shared<ParamValueLite>(); | |||
| if (gate_weight_default == nullptr) { | |||
| MS_LOG(ERROR) << "gate_weight_default is nullptr"; | |||
| return nullptr; | |||
| } | |||
| gate_weight_default->set_tensor_shape(shape); | |||
| gate_weight_default->set_tensor_type(type); | |||
| gate_weight_default->set_format(schema::Format_NHWC); | |||
| int data_len = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); | |||
| int data_size = 0; | |||
| if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) { | |||
| data_size = data_len * sizeof(float); | |||
| *tensor_data = new (std::nothrow) float[data_len]; | |||
| } else if (type == kNumberTypeInt || type == kNumberTypeInt32) { | |||
| data_size = data_len * sizeof(int); | |||
| *tensor_data = new (std::nothrow) int[data_len]; | |||
| } else { | |||
| MS_LOG(DEBUG) << "unsupported data type"; | |||
| return nullptr; | |||
| } | |||
| if (*tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "new data failed"; | |||
| return nullptr; | |||
| } | |||
| gate_weight_default->SetTensorData(*tensor_data, data_size); | |||
| parameter->set_default_param(gate_weight_default); | |||
| return parameter; | |||
| } | |||
| void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, | |||
| const int r1, const int c0, const int c1, float *data, | |||
| bool t) const { | |||
| MS_ASSERT(mat != nullptr); | |||
| MS_ASSERT(data != nullptr); | |||
| MS_ASSERT(0 <= r0 && r0 < r1 && r1 <= R); | |||
| MS_ASSERT(0 <= c0 && c0 < c1 && c1 <= C); | |||
| const int RT = r1 - r0; | |||
| const int CT = c1 - c0; | |||
| for (int i = r0; i < r1; ++i) { | |||
| for (int j = c0; j < c1; ++j) { | |||
| if (t) { | |||
| data[(j - c0) * RT + (i - r0)] = mat[i * C + j]; | |||
| } else { | |||
| data[(i - r0) * CT + (j - c0)] = mat[i * C + j]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, | |||
| const int input_size, const int hidden_size, | |||
| float *gate_tensor_data, float *recu_tensor_data) const { | |||
| MS_ASSERT(gate_weight != nullptr); | |||
| MS_ASSERT(cand_weight != nullptr); | |||
| MS_ASSERT(gate_tensor_data != nullptr); | |||
| MS_ASSERT(recu_tensor_data != nullptr); | |||
| const std::vector<int> gate_shape{input_size + hidden_size, hidden_size * 2}; | |||
| const std::vector<int> cand_shape{hidden_size * 2, hidden_size}; | |||
| auto gate_weight_value = GetDefaultParamValue(gate_weight); | |||
| if (gate_weight_value == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto gate_weight_data = reinterpret_cast<float *>(gate_weight_value->tensor_addr()); | |||
| if (gate_weight_data == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto gate_weight_shape = gate_weight_value->tensor_shape(); | |||
| auto cand_weight_value = GetDefaultParamValue(cand_weight); | |||
| if (cand_weight_value == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto cand_weight_data = reinterpret_cast<float *>(cand_weight_value->tensor_addr()); | |||
| if (cand_weight_data == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto cand_weight_shape = cand_weight_value->tensor_shape(); | |||
| if (gate_weight_shape != gate_shape || cand_weight_shape != cand_shape) { | |||
| return RET_ERROR; | |||
| } | |||
| // input_update_weight | |||
| CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, 0, input_size, hidden_size, | |||
| hidden_size * 2, gate_tensor_data, true); | |||
| // input_reset_weight | |||
| CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, 0, input_size, 0, hidden_size, | |||
| gate_tensor_data + input_size * hidden_size, true); | |||
| // input_hidden_weight | |||
| CopyFlattenMatData(cand_weight_data, input_size + hidden_size, hidden_size, 0, input_size, 0, hidden_size, | |||
| gate_tensor_data + input_size * hidden_size * 2, true); | |||
| // state_update_weight | |||
| CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, input_size, input_size + hidden_size, | |||
| hidden_size, hidden_size * 2, recu_tensor_data, true); | |||
| // state_reset_weight | |||
| CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, input_size, input_size + hidden_size, | |||
| 0, hidden_size, recu_tensor_data + hidden_size * hidden_size, true); | |||
| // state_hidden_weight | |||
| CopyFlattenMatData(cand_weight_data, input_size + hidden_size, hidden_size, input_size, input_size + hidden_size, 0, | |||
| hidden_size, recu_tensor_data + hidden_size * hidden_size * 2, true); | |||
| return RET_OK; | |||
| } | |||
| STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, | |||
| const int hidden_size, float *tensor_data) const { | |||
| MS_ASSERT(bias != nullptr); | |||
| MS_ASSERT(tensor_data != nullptr); | |||
| std::vector<int> gate_shape{hidden_size * 2}; | |||
| std::vector<int> cand_shape{hidden_size}; | |||
| auto gate_bias_value = GetDefaultParamValue(gate_bias); | |||
| if (gate_bias_value == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto gate_bias_data = reinterpret_cast<float *>(gate_bias_value->tensor_addr()); | |||
| auto gate_bias_shape = gate_bias_value->tensor_shape(); | |||
| auto cand_bias_value = GetDefaultParamValue(cand_bias); | |||
| if (cand_bias_value == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto cand_bias_data = reinterpret_cast<float *>(cand_bias_value->tensor_addr()); | |||
| auto cand_bias_shape = cand_bias_value->tensor_shape(); | |||
| if (gate_bias_shape != gate_shape || cand_bias_shape != cand_shape) { | |||
| return RET_ERROR; | |||
| } | |||
| // update_gate bias | |||
| CopyFlattenMatData(gate_bias_data, 1, hidden_size * 2, 0, 1, hidden_size, hidden_size * 2, tensor_data, false); | |||
| // reset_gate bias | |||
| CopyFlattenMatData(gate_bias_data, 1, hidden_size * 2, 0, 1, 0, hidden_size, tensor_data + hidden_size, false); | |||
| // hidden_gate bias | |||
| CopyFlattenMatData(cand_bias_data, 1, hidden_size, 0, 1, 0, hidden_size, tensor_data + hidden_size * 2, false); | |||
| return RET_OK; | |||
| } | |||
| CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &hidden_state, | |||
| const std::string base_name) const { | |||
| MS_ASSERT(func_graph); | |||
| MS_ASSERT(hidden_state); | |||
| auto stack_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::StackT> attr = std::make_unique<schema::StackT>(); | |||
| attr->axis = 0; | |||
| stack_primitive->value.type = schema::PrimitiveType_Stack; | |||
| stack_primitive->value.value = attr.release(); | |||
| auto stack_cvalue = lite::PrimitiveC::Create(stack_primitive.release()); | |||
| auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(stack_cvalue)); | |||
| std::vector<AnfNodePtr> new_node_inputs = {value_node, hidden_state, hidden_state}; | |||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||
| new_node->set_abstract(hidden_state->abstract()->Clone()); | |||
| new_node->set_fullname_with_scope("stack_hidden_" + base_name); | |||
| return new_node; | |||
| } | |||
| CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, | |||
| const EquivPtr &equiv, const EquivPtr &fw_body_equiv, | |||
| const EquivPtr &bw_body_equiv, | |||
| const std::string &base_name) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(input != nullptr); | |||
| MS_ASSERT(equiv != nullptr); | |||
| MS_ASSERT(fw_body_equiv != nullptr); | |||
| MS_ASSERT(bw_body_equiv != nullptr); | |||
| auto gru_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::GruT> attr = std::make_unique<schema::GruT>(); | |||
| attr->bidirection = true; | |||
| gru_primitive->value.type = schema::PrimitiveType_Gru; | |||
| gru_primitive->value.value = attr.release(); | |||
| auto gru_cvalue = lite::PrimitiveC::Create(gru_primitive.release()); | |||
| auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(gru_cvalue)); | |||
| auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]); | |||
| MS_ASSERT(fw_gate_kernel); | |||
| auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]); | |||
| MS_ASSERT(fw_gate_bias); | |||
| auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]); | |||
| MS_ASSERT(fw_cand_kernel); | |||
| auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]); | |||
| MS_ASSERT(fw_cand_bias); | |||
| auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]); | |||
| MS_ASSERT(bw_gate_kernel); | |||
| auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]); | |||
| MS_ASSERT(bw_gate_bias); | |||
| auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]); | |||
| MS_ASSERT(bw_cand_kernel); | |||
| auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]); | |||
| MS_ASSERT(bw_cand_bias); | |||
| auto hidden = utils::cast<AnfNodePtr>((*equiv)[common_vars_[1]]); | |||
| MS_ASSERT(hidden); | |||
| auto stacked_hidden = GetStackedHiddenState(func_graph, hidden, base_name); | |||
| if (stacked_hidden == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto input_length = utils::cast<AnfNodePtr>((*equiv)[input_length_]); | |||
| MS_ASSERT(hidden); | |||
| int input_size = 0; | |||
| int hidden_size = 0; | |||
| auto status = GetInputAndHiddenSize(fw_cand_kernel, bw_cand_kernel, &input_size, &hidden_size); | |||
| if (status != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| std::vector<int> gate_weight_shape{2, hidden_size * 3, input_size}; | |||
| float *gate_tensor_data = nullptr; | |||
| auto gate_weight = AddDefaultParameter(func_graph, base_name + "_gate_weight", gate_weight_shape, kNumberTypeFloat32, | |||
| reinterpret_cast<void **>(&gate_tensor_data)); | |||
| if (gate_weight == nullptr) { | |||
| return nullptr; | |||
| } | |||
| std::vector<int> recu_weight_shape{2, hidden_size * 3, hidden_size}; | |||
| float *recu_tensor_data = nullptr; | |||
| auto recu_weight = AddDefaultParameter(func_graph, base_name + "_cand_weight", recu_weight_shape, kNumberTypeFloat32, | |||
| reinterpret_cast<void **>(&recu_tensor_data)); | |||
| if (recu_weight == nullptr) { | |||
| return nullptr; | |||
| } | |||
| std::vector<int> bias_shape{2, hidden_size * 6}; | |||
| float *bias_tensor_data = nullptr; | |||
| auto bias = AddDefaultParameter(func_graph, base_name + "_bias", bias_shape, kNumberTypeFloat32, | |||
| reinterpret_cast<void **>(&bias_tensor_data)); | |||
| if (bias == nullptr) { | |||
| return nullptr; | |||
| } | |||
| for (int i = 0; i < 2 * hidden_size * 6; ++i) { | |||
| bias_tensor_data[i] = 0.0f; | |||
| } | |||
| if (ConvertWeightData(fw_gate_kernel, fw_cand_kernel, input_size, hidden_size, gate_tensor_data, recu_tensor_data) != | |||
| RET_OK) { | |||
| return nullptr; | |||
| } | |||
| auto gate_data_diff = hidden_size * input_size * 3; | |||
| auto recu_data_diff = hidden_size * hidden_size * 3; | |||
| if (ConvertWeightData(bw_gate_kernel, bw_cand_kernel, input_size, hidden_size, gate_tensor_data + gate_data_diff, | |||
| recu_tensor_data + recu_data_diff) != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| if (ConvertBiasData(fw_gate_bias, fw_cand_bias, hidden_size, bias_tensor_data) != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| auto bias_data_diff = hidden_size * 6; | |||
| if (ConvertBiasData(bw_gate_bias, bw_cand_bias, hidden_size, bias_tensor_data + bias_data_diff) != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> new_node_inputs = {value_node, input, gate_weight, recu_weight, | |||
| bias, stacked_hidden, input_length}; | |||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||
| new_node->set_fullname_with_scope(base_name); | |||
| return new_node; | |||
| } | |||
| CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, | |||
| const std::string base_name) const { | |||
| MS_ASSERT(func_graph); | |||
| MS_ASSERT(gru_output); | |||
| auto split_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::SplitT> split_attr = std::make_unique<schema::SplitT>(); | |||
| split_attr->numberSplit = 2; | |||
| split_attr->splitDim = 1; | |||
| split_primitive->value.type = schema::PrimitiveType_Split; | |||
| split_primitive->value.value = split_attr.release(); | |||
| auto split_cvalue = lite::PrimitiveC::Create(split_primitive.release()); | |||
| auto split_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(split_cvalue)); | |||
| std::vector<AnfNodePtr> new_node_inputs = {split_value_node, gru_output}; | |||
| auto split_new_node = func_graph->NewCNode(new_node_inputs); | |||
| split_new_node->set_fullname_with_scope("split_" + base_name); | |||
| if (TfliteLstmCellFusion::SetAbstractTuple(split_new_node, 2) != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| auto split_out1 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 0); | |||
| if (split_out1 == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto split_out2 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 1); | |||
| if (split_out2 == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto concat_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::ConcatT> concat_attr = std::make_unique<schema::ConcatT>(); | |||
| concat_attr->axis = 3; | |||
| concat_primitive->value.type = schema::PrimitiveType_Concat; | |||
| concat_primitive->value.value = concat_attr.release(); | |||
| auto concat_cvalue = lite::PrimitiveC::Create(concat_primitive.release()); | |||
| auto concat_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(concat_cvalue)); | |||
| std::vector<AnfNodePtr> concat_new_node_inputs = {concat_value_node, split_out1, split_out2}; | |||
| auto concat_new_node = func_graph->NewCNode(concat_new_node_inputs); | |||
| concat_new_node->set_fullname_with_scope("concat_" + base_name); | |||
| concat_new_node->set_abstract(gru_output->abstract()->Clone()); | |||
| auto squeeze_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::SqueezeT> squeeze_attr = std::make_unique<schema::SqueezeT>(); | |||
| squeeze_attr->axis = std::vector<int>{1}; | |||
| squeeze_primitive->value.type = schema::PrimitiveType_Squeeze; | |||
| squeeze_primitive->value.value = squeeze_attr.release(); | |||
| auto squeeze_cvalue = lite::PrimitiveC::Create(squeeze_primitive.release()); | |||
| auto squeeze_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(squeeze_cvalue)); | |||
| std::vector<AnfNodePtr> squeeze_new_node_inputs = {squeeze_value_node, concat_new_node}; | |||
| auto squeeze_new_node = func_graph->NewCNode(squeeze_new_node_inputs); | |||
| squeeze_new_node->set_fullname_with_scope("squeeze_" + base_name); | |||
| squeeze_new_node->set_abstract(gru_output->abstract()->Clone()); | |||
| auto transpose_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::TransposeT> transpose_attr = std::make_unique<schema::TransposeT>(); | |||
| transpose_attr->perm = std::vector<int>{1, 0, 2}; | |||
| transpose_primitive->value.type = schema::PrimitiveType_Transpose; | |||
| transpose_primitive->value.value = transpose_attr.release(); | |||
| auto transpose_cvalue = lite::PrimitiveC::Create(transpose_primitive.release()); | |||
| auto transpose_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(transpose_cvalue)); | |||
| std::vector<AnfNodePtr> transpose_new_node_inputs = {transpose_value_node, squeeze_new_node}; | |||
| auto transpose_new_node = func_graph->NewCNode(transpose_new_node_inputs); | |||
| transpose_new_node->set_fullname_with_scope("transpose_" + base_name); | |||
| transpose_new_node->set_abstract(gru_output->abstract()->Clone()); | |||
| return transpose_new_node; | |||
| } | |||
| const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node, | |||
| const EquivPtr &equiv) const { | |||
| MS_ASSERT(func_graph); | |||
| MS_ASSERT(concat_node); | |||
| MS_LOG(DEBUG) << "bidirection tf gru fusion pass"; | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]); | |||
| MS_ASSERT(transpose_input); | |||
| if (!utils::isa<CNodePtr>(transpose_input) || GetCNodeType(transpose_input) != schema::PrimitiveType_Transpose) { | |||
| return nullptr; | |||
| } | |||
| PrimitiveVarMapPtr fw_cond_primitive_vars = std::make_shared<PrimitiveVarMap>(); | |||
| auto fw_cond_graph_pattern = GetCondGraphPattern(fw_cond_primitive_vars); | |||
| auto fw_cond = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[0]]); | |||
| MS_ASSERT(fw_cond != nullptr); | |||
| auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_cond_graph_pattern, fw_cond_primitive_vars, | |||
| fw_cond, kCondCNodesNum, kCondNodesNum); | |||
| if (fw_cond_equiv == nullptr || fw_cond_equiv->empty()) { | |||
| return nullptr; | |||
| } | |||
| PrimitiveVarMapPtr bw_cond_primitive_vars = std::make_shared<PrimitiveVarMap>(); | |||
| auto bw_cond_graph_pattern = GetCondGraphPattern(bw_cond_primitive_vars); | |||
| auto bw_cond = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[0]]); | |||
| MS_ASSERT(bw_cond != nullptr); | |||
| auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_cond_graph_pattern, bw_cond_primitive_vars, | |||
| bw_cond, kCondCNodesNum, kCondNodesNum); | |||
| if (bw_cond_equiv == nullptr || bw_cond_equiv->empty()) { | |||
| return nullptr; | |||
| } | |||
| PrimitiveVarMapPtr fw_primitive_vars_body = std::make_shared<PrimitiveVarMap>(); | |||
| auto fw_body_graph_pattern = GetBodyGraphPattern(fw_primitive_vars_body); | |||
| auto fw_body = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[1]]); | |||
| MS_ASSERT(fw_body != nullptr); | |||
| auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_body_graph_pattern, fw_primitive_vars_body, | |||
| fw_body, kBodyCNodesNum, kBodyNodesNum); | |||
| if (fw_body_equiv == nullptr || fw_body_equiv->empty()) { | |||
| return nullptr; | |||
| } | |||
| PrimitiveVarMapPtr bw_primitive_vars_body = std::make_shared<PrimitiveVarMap>(); | |||
| auto bw_body_graph_pattern = GetBodyGraphPattern(bw_primitive_vars_body); | |||
| auto bw_body = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[1]]); | |||
| MS_ASSERT(bw_body != nullptr); | |||
| auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_body_graph_pattern, bw_primitive_vars_body, | |||
| bw_body, kBodyCNodesNum, kBodyNodesNum); | |||
| if (bw_body_equiv == nullptr || bw_body_equiv->empty()) { | |||
| return nullptr; | |||
| } | |||
| const std::string gru_name = "gru_" + concat_node->fullname_with_scope(); | |||
| auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, fw_body_equiv, bw_body_equiv, gru_name); | |||
| if (gru_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0); | |||
| if (get_item_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope()); | |||
| MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success"; | |||
| return output_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class BiDirectionTfGruCellFusion : public PatternProcessPass { | |||
| public: | |||
| explicit BiDirectionTfGruCellFusion(const std::string &name = "bidirection_tf_gru_cell_fusion", | |||
| bool multigraph = true); | |||
| ~BiDirectionTfGruCellFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| protected: | |||
| virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; | |||
| private: | |||
| AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; | |||
| CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv, | |||
| const EquivPtr &fw_body_equiv, const EquivPtr &bw_body_equiv, | |||
| const std::string &base_name) const; | |||
| ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const; | |||
| lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf, | |||
| int *input_size, int *hidden_size) const; | |||
| ParameterPtr AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name, | |||
| const std::vector<int> &shape, const TypeId type, void **tensor_data) const; | |||
| lite::STATUS ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, const int input_size, | |||
| const int hidden_size, float *gate_tensor_data, float *recu_tensor_data) const; | |||
| lite::STATUS ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, const int hidden_size, | |||
| float *tensor_data) const; | |||
| void CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, const int c0, | |||
| const int c1, float *data, bool t = false) const; | |||
| CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &hidden_state, | |||
| const std::string base_name) const; | |||
| CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, | |||
| const std::string base_name) const; | |||
| private: | |||
| std::vector<VarPtr> common_vars_; | |||
| std::vector<VarPtr> fw_vars_; | |||
| std::vector<VarPtr> bw_vars_; | |||
| VarPtr input_; | |||
| VarPtr input_length_; | |||
| VarPtr transpose_input_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ | |||
| @@ -0,0 +1,370 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h" | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "securec/include/securec.h" | |||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kLstmInputsLength = 13; | |||
| constexpr size_t kLstmInputsVarNum = 11; | |||
| constexpr size_t kCondNodesNum = 12; | |||
| constexpr size_t kCondCNodesNum = 4; | |||
| constexpr size_t kBodyNodesNum = 82; | |||
| constexpr size_t kBodyCNodesNum = 30; | |||
| const auto &p1 = std::placeholders::_1; | |||
| bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); } | |||
| bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| return opt::GetCNodeType(n) == type; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| TfLstmCellFusion::TfLstmCellFusion(const std::string &name, bool multigraph) | |||
| : TfliteLstmCellFusion(name, multigraph, kLstmInputsLength, kLstmInputsVarNum, kCondNodesNum, kCondCNodesNum, | |||
| kBodyNodesNum, kBodyCNodesNum) { | |||
| /* | |||
| * vars for lstm cell input | |||
| * 0:cond 1:body 2:index 3:limit1 4:output 5:cell 6:hidden 7:limit2 8:input 9:kernel 10:bias | |||
| */ | |||
| forget_bias_ = std::make_shared<Var>(); | |||
| } | |||
| AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| std::vector<CondVarPtr> placeholders; | |||
| for (int i = 0; i < 10; ++i) { | |||
| placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode)); | |||
| } | |||
| VectorRef add2 = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef add3 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef get_item = VectorRef( | |||
| {std::make_shared<Var>("GetItem"), placeholders[7], placeholders[2], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[5]}); | |||
| VectorRef matmul = VectorRef({std::make_shared<Var>(), concat_input_h, placeholders[8]}); | |||
| VectorRef bias = VectorRef({std::make_shared<Var>(), matmul, placeholders[9]}); | |||
| VectorRef split = VectorRef({std::make_shared<Var>(), bias}); | |||
| VectorRef get_item1 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()}); | |||
| VectorRef get_item2 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()}); | |||
| VectorRef get_item3 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()}); | |||
| VectorRef get_item4 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()}); | |||
| VectorRef input_gate = VectorRef({std::make_shared<Var>("Sigmoid"), get_item1}); | |||
| VectorRef input_to_cell = VectorRef({std::make_shared<Var>("Tanh"), get_item2}); | |||
| VectorRef forget_bias = VectorRef({std::make_shared<Var>("Add"), get_item3, forget_bias_}); | |||
| VectorRef forget_gate = VectorRef({std::make_shared<Var>("Sigmoid"), forget_bias}); | |||
| VectorRef output_gate = VectorRef({std::make_shared<Var>("Sigmoid"), get_item4}); | |||
| VectorRef forgetted_cell = VectorRef({std::make_shared<Var>(""), forget_gate, placeholders[4]}); | |||
| VectorRef inputted_cell = VectorRef({std::make_shared<Var>(""), input_gate, input_to_cell}); | |||
| VectorRef input_forget_cell = VectorRef({std::make_shared<Var>("Add"), forgetted_cell, inputted_cell}); | |||
| VectorRef to_new_hidden = VectorRef({std::make_shared<Var>("Tanh"), input_forget_cell}); | |||
| VectorRef new_hidden = VectorRef({std::make_shared<Var>("Mul"), output_gate, to_new_hidden}); | |||
| VectorRef new_to_cell = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_new_, input_forget_cell}); | |||
| VectorRef old_to_cell = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_old_, placeholders[4]}); | |||
| VectorRef output_cell = VectorRef({std::make_shared<Var>("Add"), new_to_cell, old_to_cell}); | |||
| VectorRef new_to_hidden = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_new_, new_hidden}); | |||
| VectorRef old_to_hidden = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_old_, placeholders[5]}); | |||
| VectorRef output_hidden = VectorRef({std::make_shared<Var>("Add"), new_to_hidden, old_to_hidden}); | |||
| VectorRef set_item = VectorRef({std::make_shared<Var>(""), placeholders[3], placeholders[2], new_hidden}); | |||
| auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple)); | |||
| std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden}; | |||
| outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end()); | |||
| VectorRef make_tuple_node = VectorRef(outputs); | |||
| auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); | |||
| VectorRef return_node = VectorRef({is_return, make_tuple_node}); | |||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||
| auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true); | |||
| return pattern; | |||
| } | |||
| STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector<int> &shape, | |||
| const float *const data_ptr, const int hidden_size) const { | |||
| MS_ASSERT(weight != nullptr); | |||
| MS_ASSERT(data_ptr != nullptr); | |||
| auto default_param = std::make_shared<ParamValueLite>(); | |||
| if (default_param == nullptr) { | |||
| MS_LOG(ERROR) << "new_default is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| default_param->set_tensor_shape(shape); | |||
| default_param->set_tensor_type(kNumberTypeFloat32); | |||
| default_param->set_format(schema::Format_NHWC); | |||
| if (shape.size() != 3) { | |||
| MS_LOG(ERROR) << "lstm weight shape must have 3 dims"; | |||
| return RET_ERROR; | |||
| } | |||
| const auto param_num = shape[0] * shape[1] * shape[2]; | |||
| auto tensor_data = new (std::nothrow) float[param_num * 4]; | |||
| std::vector<int> data_diff{0, 3, 2, 1}; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(DEBUG) << "new data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| for (int i = 0; i < 4; ++i) { | |||
| for (int j = 0; j < hidden_size; ++j) { | |||
| for (int t = 0; t < shape[2]; ++t) { | |||
| tensor_data[(i * hidden_size + j) * shape[2] + t] = data_ptr[t * shape[1] + data_diff[i] * hidden_size + j]; | |||
| } | |||
| } | |||
| } | |||
| default_param->SetTensorData(tensor_data, param_num * 4); | |||
| weight->set_default_param(default_param); | |||
| std::vector<int64_t> shape_vector_i(shape.begin(), shape.end()); | |||
| auto abstract_tensor_i = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector_i); | |||
| if (abstract_tensor_i == nullptr) { | |||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||
| delete[] tensor_data; | |||
| return RET_ERROR; | |||
| } | |||
| weight->set_abstract(abstract_tensor_i); | |||
| return RET_OK; | |||
| } | |||
| STATUS TfLstmCellFusion::SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i, | |||
| const ParameterPtr &weight_c, int hidden_size) const { | |||
| // split input_size and hidden_size at dim 0 | |||
| // transform i,c,f,o to i,o,f,c at dim 1 | |||
| MS_ASSERT(weight != nullptr); | |||
| MS_ASSERT(wiehgt_i != nullptr); | |||
| MS_ASSERT(wiehgt_c != nullptr); | |||
| if (!utils::isa<ParameterPtr>(weight)) { | |||
| return RET_ERROR; | |||
| } | |||
| auto weight_param = utils::cast<ParameterPtr>(weight); | |||
| if (!weight_param->has_default()) { | |||
| MS_LOG(DEBUG) << "weight not have default value"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<ParamValueLitePtr>(weight_param->default_param())) { | |||
| MS_LOG(DEBUG) << "default value is not ParamValueLite"; | |||
| return RET_FAILED; | |||
| } | |||
| auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(weight_param->default_param()); | |||
| if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { | |||
| MS_LOG(DEBUG) << "origin_tensor is not float32 type"; | |||
| return RET_ERROR; | |||
| } | |||
| auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr()); | |||
| auto data_shape = origin_tensor->tensor_shape(); | |||
| if (data_shape.size() != 2) { | |||
| MS_LOG(ERROR) << "weight data shape invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| if (data_shape[0] <= hidden_size) { | |||
| MS_LOG(ERROR) << "weight data shape[0] invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| if (hidden_size * 4 != data_shape[1]) { | |||
| MS_LOG(ERROR) << "weight data shape[1] invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| const auto input_size = data_shape[0] - hidden_size; | |||
| std::vector<int> shape_i{1, 4 * hidden_size, input_size}; | |||
| if (SetWeightAbstractAndDefault(weight_i, shape_i, data_ptr, hidden_size) != RET_OK) { | |||
| MS_LOG(ERROR) << "get weight_i failed"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> shape_c{1, 4 * hidden_size, hidden_size}; | |||
| if (SetWeightAbstractAndDefault(weight_c, shape_c, data_ptr + input_size * data_shape[1], hidden_size) != RET_OK) { | |||
| MS_LOG(ERROR) << "get weight_i failed"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias, | |||
| const AnfNodePtr &old_bias, const int hidden_size) const { | |||
| MS_ASSERT(body_equiv != nullptr); | |||
| MS_ASSERT(new_bias != nullptr); | |||
| MS_ASSERT(old_bias != nullptr); | |||
| if (!utils::isa<ParameterPtr>(old_bias)) { | |||
| MS_LOG(DEBUG) << "old_bias is not parameter"; | |||
| return RET_ERROR; | |||
| } | |||
| auto old_bias_param = utils::cast<ParameterPtr>(old_bias); | |||
| if (!old_bias_param->has_default()) { | |||
| MS_LOG(DEBUG) << "bias not have default value"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<ParamValueLitePtr>(old_bias_param->default_param())) { | |||
| MS_LOG(DEBUG) << "default value is not ParamValueLite"; | |||
| return RET_FAILED; | |||
| } | |||
| auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(old_bias_param->default_param()); | |||
| if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { | |||
| MS_LOG(DEBUG) << "origin_tensor is not float32 type"; | |||
| return RET_ERROR; | |||
| } | |||
| auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr()); | |||
| auto data_shape = origin_tensor->tensor_shape(); | |||
| if (data_shape.size() != 1 || data_shape[0] != 4 * hidden_size) { | |||
| MS_LOG(DEBUG) << "bias data shape illegal"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> shape{1, 8 * hidden_size}; | |||
| auto default_param = std::make_shared<ParamValueLite>(); | |||
| if (default_param == nullptr) { | |||
| MS_LOG(ERROR) << "new_default is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| default_param->set_tensor_shape(shape); | |||
| default_param->set_tensor_type(kNumberTypeFloat32); | |||
| default_param->set_format(schema::Format_NHWC); | |||
| auto tensor_data = new (std::nothrow) float[hidden_size * 8]; | |||
| auto forget_bias_node = utils::cast<AnfNodePtr>((*body_equiv)[forget_bias_]); | |||
| if (forget_bias_node == nullptr) { | |||
| MS_LOG(ERROR) << "forget bias node is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| float forget_bias_value = 0.0f; | |||
| if (GetFloatScalarFromParamValueLite(forget_bias_node, &forget_bias_value) != RET_OK) { | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> data_diff{0, 3, 2, 1}; | |||
| for (int i = 0; i < 8; ++i) { | |||
| for (int j = 0; j < hidden_size; ++j) { | |||
| if (i < 4) { | |||
| tensor_data[i * hidden_size + j] = data_ptr[data_diff[i] * hidden_size + j]; | |||
| if (i == 2) { // forget bias | |||
| tensor_data[i * hidden_size + j] += forget_bias_value; | |||
| } | |||
| } else { | |||
| tensor_data[i * hidden_size + j] = 0.0f; | |||
| } | |||
| } | |||
| } | |||
| default_param->SetTensorData(tensor_data, hidden_size * 8 * 4); | |||
| new_bias->set_default_param(default_param); | |||
| std::vector<int64_t> shape_vector_i(shape.begin(), shape.end()); | |||
| auto abstract_tensor_i = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector_i); | |||
| if (abstract_tensor_i == nullptr) { | |||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||
| delete[] tensor_data; | |||
| return RET_ERROR; | |||
| } | |||
| new_bias->set_abstract(abstract_tensor_i); | |||
| return RET_OK; | |||
| } | |||
| CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||
| const EquivPtr &body_equiv, const std::string &base_name, | |||
| const float smooth) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(equiv != nullptr); | |||
| auto lstm_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>(); | |||
| attr->bidirection = false; | |||
| attr->smooth = smooth; | |||
| lstm_primitive->value.type = schema::PrimitiveType_Lstm; | |||
| lstm_primitive->value.value = attr.release(); | |||
| auto lstm_cvalue = lite::PrimitiveC::Create(lstm_primitive.release()); | |||
| auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(lstm_cvalue)); | |||
| auto &vars = while_input_vars_; | |||
| auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]); | |||
| MS_ASSERT(limit1); | |||
| auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]); | |||
| MS_ASSERT(limit2); | |||
| auto weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]); | |||
| MS_ASSERT(weight); | |||
| auto bias = utils::cast<AnfNodePtr>((*equiv)[vars[10]]); | |||
| MS_ASSERT(bias); | |||
| auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]); | |||
| MS_ASSERT(input); | |||
| auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]); | |||
| MS_ASSERT(cell); | |||
| auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]); | |||
| MS_ASSERT(hidden); | |||
| if (!utils::isa<ParameterPtr>(hidden)) { | |||
| MS_LOG(DEBUG) << "hidden is not parameter"; | |||
| return nullptr; | |||
| } | |||
| auto hidden_param = utils::cast<ParameterPtr>(hidden); | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(hidden_param->abstract())) { | |||
| MS_LOG(DEBUG) << "hidden abstract is not AbstractTensor"; | |||
| return nullptr; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract()); | |||
| auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||
| if (hidden_shape.size() == 0) { | |||
| MS_LOG(DEBUG) << "can't get hidden shape"; | |||
| return nullptr; | |||
| } | |||
| auto i_weight = func_graph->add_parameter(); | |||
| i_weight->set_name(base_name + "_weight_i"); | |||
| i_weight->set_abstract(weight->abstract()->Clone()); | |||
| auto c_weight = func_graph->add_parameter(); | |||
| c_weight->set_name(base_name + "_weight_c"); | |||
| c_weight->set_abstract(weight->abstract()->Clone()); | |||
| if (SplitWeights(weight, i_weight, c_weight, hidden_shape.back()) != RET_OK) { | |||
| MS_LOG(DEBUG) << "split weight to i_weight and c_weight failed"; | |||
| return nullptr; | |||
| } | |||
| auto bias_node = func_graph->add_parameter(); | |||
| bias_node->set_name(base_name + "_bias"); | |||
| bias_node->set_abstract(bias->abstract()->Clone()); | |||
| if (PopulateBiasNode(body_equiv, bias_node, bias, hidden_shape.back()) != RET_OK) { | |||
| MS_LOG(DEBUG) << "reorder bias failed"; | |||
| return nullptr; | |||
| } | |||
| if (!utils::isa<CNodePtr>(input) || GetCNodeType(input) != schema::PrimitiveType_TensorListFromTensor) { | |||
| MS_LOG(DEBUG) << "input is not tensorlistfromtensor op"; | |||
| return nullptr; | |||
| } | |||
| auto tensor_list_cnode = utils::cast<CNodePtr>(input); | |||
| auto input_tensor_node = tensor_list_cnode->input(1); | |||
| std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias_node, hidden, | |||
| cell}; | |||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||
| new_node->set_fullname_with_scope(base_name); | |||
| return new_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TfLstmCellFusion : public TfliteLstmCellFusion { | |||
| public: | |||
| explicit TfLstmCellFusion(const std::string &name = "lstm_cell_fusion", bool multigraph = true); | |||
| ~TfLstmCellFusion() override = default; | |||
| private: | |||
| AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const override; | |||
| CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv, | |||
| const std::string &base_name, const float smooth) const override; | |||
| lite::STATUS SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i, const ParameterPtr &weight_c, | |||
| int hidden_size) const; | |||
| lite::STATUS SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector<int> &shape, | |||
| const float *const data_ptr, const int hidden_size) const; | |||
| lite::STATUS PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias, const AnfNodePtr &old_bias, | |||
| const int hidden_size) const; | |||
| private: | |||
| VarPtr forget_bias_ = nullptr; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ | |||
| @@ -0,0 +1,727 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||
| #include <memory> | |||
| #include <functional> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "securec/include/securec.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kWhileInputsLength = 23; | |||
| constexpr size_t kWhileInputsVarNum = 21; | |||
| constexpr size_t kCondNodesNum = 12; | |||
| constexpr size_t kCondCNodesNum = 4; | |||
| constexpr size_t kBodyNodesNum = 95; | |||
| constexpr size_t kBodyCNodesNum = 34; | |||
| constexpr size_t kLSTMOutputNum = 3; | |||
| const auto &p1 = std::placeholders::_1; | |||
| constexpr float EPSILON = 1e-5; | |||
| bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); } | |||
| bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| return opt::GetCNodeType(n) == type; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| STATUS TfliteLstmCellFusion::GetFloatScalarFromParamValueLite(const AnfNodePtr ¶m_value, float *v) const { | |||
| if (param_value == nullptr || v == nullptr) { | |||
| MS_LOG(ERROR) << "param_value or v is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<ParameterPtr>(param_value)) { | |||
| MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto param_ptr = utils::cast<ParameterPtr>(param_value); | |||
| if (!param_ptr->has_default()) { | |||
| MS_LOG(DEBUG) << "param not have default"; | |||
| return RET_ERROR; | |||
| } | |||
| auto default_param = param_ptr->default_param(); | |||
| if (!utils::isa<ParamValueLitePtr>(default_param)) { | |||
| MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto default_param_ptr = utils::cast<ParamValueLitePtr>(default_param); | |||
| auto tensor_shape = default_param_ptr->tensor_shape(); | |||
| if (!(tensor_shape.size() == 0 || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) { | |||
| MS_LOG(DEBUG) << "default param is not scalar"; | |||
| return RET_ERROR; | |||
| } | |||
| if (default_param_ptr->tensor_type() != kNumberTypeFloat32 && default_param_ptr->tensor_type() != kNumberTypeFloat) { | |||
| MS_LOG(DEBUG) << "default param is not float"; | |||
| return RET_ERROR; | |||
| } | |||
| *v = *(reinterpret_cast<float *>(default_param_ptr->tensor_addr())); | |||
| return RET_OK; | |||
| } | |||
| TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num, | |||
| int cond_nodes_num, int cond_cnodes_num, int body_nodes_num, | |||
| int body_cnodes_num) | |||
| : PatternProcessPass(name, multigraph) { | |||
| /* | |||
| * input vars for lstm while node | |||
| * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_ | |||
| * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_ | |||
| */ | |||
| this->while_inputs_num_ = input_length == 0 ? kWhileInputsLength : input_length; | |||
| this->while_input_var_num_ = var_num == 0 ? kWhileInputsVarNum : var_num; | |||
| this->cond_nodes_num_ = cond_nodes_num == 0 ? kCondNodesNum : cond_nodes_num; | |||
| this->cond_cnodes_num_ = cond_cnodes_num == 0 ? kCondCNodesNum : cond_cnodes_num; | |||
| this->body_nodes_num_ = body_nodes_num == 0 ? kBodyNodesNum : body_nodes_num; | |||
| this->body_cnodes_num_ = body_cnodes_num == 0 ? kBodyCNodesNum : body_cnodes_num; | |||
| for (size_t i = 0; i < this->while_input_var_num_; ++i) { | |||
| while_input_vars_.emplace_back(std::make_shared<Var>()); | |||
| } | |||
| cell_smooth_old_ = std::make_shared<Var>(); | |||
| cell_smooth_new_ = std::make_shared<Var>(); | |||
| hidden_smooth_old_ = std::make_shared<Var>(); | |||
| hidden_smooth_new_ = std::make_shared<Var>(); | |||
| } | |||
| AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); | |||
| auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); | |||
| auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_LogicalAnd)); | |||
| auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); | |||
| VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2}); | |||
| VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4}); | |||
| VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref}); | |||
| VectorRef return_ref = VectorRef({is_return, logicaland_ref}); | |||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||
| auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true); | |||
| return pattern; | |||
| } | |||
| AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| std::vector<CondVarPtr> placeholders; | |||
| for (int i = 0; i < 20; ++i) { | |||
| placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode)); | |||
| } | |||
| VectorRef add2 = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef add3 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef concat_i_w = VectorRef({std::make_shared<Var>(), placeholders[8], placeholders[12]}); | |||
| VectorRef concat_f_w = VectorRef({std::make_shared<Var>(), placeholders[9], placeholders[13]}); | |||
| VectorRef concat_c_w = VectorRef({std::make_shared<Var>(), placeholders[10], placeholders[14]}); | |||
| VectorRef concat_o_w = VectorRef({std::make_shared<Var>(), placeholders[11], placeholders[15]}); | |||
| VectorRef get_item = VectorRef( | |||
| {std::make_shared<Var>("GetItem"), placeholders[7], placeholders[2], std::make_shared<CondVar>(IsParameterNode)}); | |||
| VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[5]}); | |||
| VectorRef matmul_input = VectorRef({std::make_shared<Var>(), concat_input_h, concat_i_w}); | |||
| VectorRef matmul_forget = VectorRef({std::make_shared<Var>(), concat_input_h, concat_f_w}); | |||
| VectorRef matmul_cell = VectorRef({std::make_shared<Var>(), concat_input_h, concat_c_w}); | |||
| VectorRef matmul_output = VectorRef({std::make_shared<Var>(), concat_input_h, concat_o_w}); | |||
| VectorRef bias_input = VectorRef({std::make_shared<Var>(), matmul_input, placeholders[16]}); | |||
| VectorRef bias_forget = VectorRef({std::make_shared<Var>(), matmul_forget, placeholders[17]}); | |||
| VectorRef bias_cell = VectorRef({std::make_shared<Var>(), matmul_cell, placeholders[18]}); | |||
| VectorRef bias_output = VectorRef({std::make_shared<Var>(), matmul_output, placeholders[19]}); | |||
| VectorRef cell = VectorRef({std::make_shared<Var>("Tanh"), bias_cell}); | |||
| VectorRef input_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_input}); | |||
| VectorRef cell_input = VectorRef({std::make_shared<Var>("Mul"), input_gate, cell}); | |||
| VectorRef forget_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_forget}); | |||
| VectorRef cell_forgeted = VectorRef({std::make_shared<Var>("Mul"), forget_gate, placeholders[4]}); | |||
| VectorRef cell_new = VectorRef({std::make_shared<Var>("Add"), cell_forgeted, cell_input}); | |||
| VectorRef smooth_cell_old = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_old_, placeholders[4]}); | |||
| VectorRef smooth_cell_new = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_new_, cell_new}); | |||
| VectorRef cell_output = VectorRef({std::make_shared<Var>("Add"), smooth_cell_new, smooth_cell_old}); | |||
| VectorRef output_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_output}); | |||
| VectorRef cell_to_output = VectorRef({std::make_shared<Var>("Tanh"), cell_new}); | |||
| VectorRef output = VectorRef({std::make_shared<Var>("Mul"), output_gate, cell_to_output}); | |||
| VectorRef smooth_hidden_old = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_old_, placeholders[5]}); | |||
| VectorRef smooth_hidden_new = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_new_, output}); | |||
| VectorRef hidden_output = VectorRef({std::make_shared<Var>("Add"), smooth_hidden_new, smooth_hidden_old}); | |||
| VectorRef set_item = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], output}); | |||
| auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple)); | |||
| std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output}; | |||
| outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end()); | |||
| VectorRef make_tuple_node = VectorRef(outputs); | |||
| auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); | |||
| VectorRef return_node = VectorRef({is_return, make_tuple_node}); | |||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||
| auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true); | |||
| return pattern; | |||
| } | |||
| const BaseRef TfliteLstmCellFusion::DefinePattern() const { | |||
| auto is_while_node = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While)); | |||
| VectorRef while_node = VectorRef({is_while_node}); | |||
| auto while_inputs = while_input_vars_; | |||
| while_inputs.insert(while_inputs.begin() + 4, while_input_vars_[2]); | |||
| while_node.insert(while_node.end(), while_inputs.begin(), while_inputs.end()); | |||
| auto is_tuple_get_item = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)); | |||
| VectorRef while_output = VectorRef({is_tuple_get_item, while_node, std::make_shared<Var>()}); | |||
| auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)); | |||
| auto is_parameter = std::make_shared<CondVar>(IsParameterNode); | |||
| VectorRef tensor_list_stack_node = VectorRef({is_tensor_list_stack, while_output, is_parameter}); | |||
| return tensor_list_stack_node; | |||
| } | |||
| EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars, | |||
| const AnfNodePtr &pattern) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(pattern != nullptr); | |||
| auto return_node = func_graph->get_return(); | |||
| PatternEngine pattern_engine(PatternEngine(std::make_shared<DefaultVisitor>(), | |||
| std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual), | |||
| std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))); | |||
| auto empty_equiv = std::make_shared<Equiv>(); | |||
| EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv); | |||
| return equiv; | |||
| } | |||
| // make sure that only 3,4,5 output of while is referenced | |||
| bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(while_cnode != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto while_node_users = manager->node_users()[while_cnode]; | |||
| std::vector<size_t> valid_indexes{3, 4, 5}; | |||
| for (auto &node_user : while_node_users) { | |||
| if (!utils::isa<CNodePtr>(node_user.first)) { | |||
| return false; | |||
| } | |||
| auto cnode = utils::cast<CNodePtr>(node_user.first); | |||
| if (GetCNodeType(cnode) != schema::PrimitiveType_TupleGetItem) { | |||
| return false; | |||
| } | |||
| auto index = GetTupleGetItemOutIndex(cnode); | |||
| if (!lite::IsContain(valid_indexes, index)) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| EquivPtr TfliteLstmCellFusion::CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern, | |||
| const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph, | |||
| const size_t cnode_num, const size_t all_node_num) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(pattern != nullptr); | |||
| MS_ASSERT(anf_sub_graph != nullptr); | |||
| auto sub_graph = GetValueNode<FuncGraphPtr>(anf_sub_graph); | |||
| auto nodes = TopoSort(sub_graph->get_return()); | |||
| auto cnodes = sub_graph->GetOrderedCnodes(); | |||
| if (cnodes.size() != cnode_num || nodes.size() != all_node_num) { | |||
| MS_LOG(DEBUG) << "sub graph nodes num not match"; | |||
| return nullptr; | |||
| } | |||
| return MatchGraph(sub_graph, primitive_vars, pattern); | |||
| } | |||
| bool TfliteLstmCellFusion::CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||
| const CNodePtr &while_cnode, float *smooth) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(equiv != nullptr); | |||
| MS_ASSERT(while_cnode != nullptr); | |||
| MS_ASSERT(smooth != nullptr); | |||
| auto cell_smooth_old_node = utils::cast<AnfNodePtr>((*equiv)[cell_smooth_old_]); | |||
| MS_ASSERT(cell_smooth_old_node != nullptr); | |||
| auto cell_smooth_new_node = utils::cast<AnfNodePtr>((*equiv)[cell_smooth_new_]); | |||
| MS_ASSERT(cell_smooth_new_node != nullptr); | |||
| auto hidden_smooth_old_node = utils::cast<AnfNodePtr>((*equiv)[hidden_smooth_old_]); | |||
| MS_ASSERT(hidden_smooth_old_node != nullptr); | |||
| auto hidden_smooth_new_node = utils::cast<AnfNodePtr>((*equiv)[hidden_smooth_new_]); | |||
| MS_ASSERT(hidden_smooth_new_node != nullptr); | |||
| float cell_old, cell_new, hidden_old, hidden_new; | |||
| if (GetFloatScalarFromParamValueLite(cell_smooth_old_node, &cell_old) != RET_OK) { | |||
| return false; | |||
| } | |||
| if (GetFloatScalarFromParamValueLite(cell_smooth_new_node, &cell_new) != RET_OK) { | |||
| return false; | |||
| } | |||
| if (GetFloatScalarFromParamValueLite(hidden_smooth_old_node, &hidden_old) != RET_OK) { | |||
| return false; | |||
| } | |||
| if (GetFloatScalarFromParamValueLite(hidden_smooth_new_node, &hidden_new) != RET_OK) { | |||
| return false; | |||
| } | |||
| if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) { | |||
| MS_LOG(DEBUG) << "cell smooth value illegal"; | |||
| return false; | |||
| } | |||
| if (hidden_old < 0.0f || hidden_old > 1.0f || hidden_new < 0.0f || hidden_new > 1.0f) { | |||
| MS_LOG(DEBUG) << "hidden smooth value illegal"; | |||
| return false; | |||
| } | |||
| if (std::abs(cell_old + cell_new - 1.0f) > EPSILON || std::abs(hidden_old + hidden_new - 1.0f) > EPSILON || | |||
| std::abs(cell_old - hidden_old) > EPSILON) { | |||
| MS_LOG(DEBUG) << "smooth value illegal"; | |||
| return false; | |||
| } | |||
| *smooth = cell_old; | |||
| return true; | |||
| } | |||
| STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> ¶ms, const ParameterPtr &new_param, | |||
| bool is_bias) const { | |||
| MS_ASSERT(new_param != nullptr); | |||
| MS_ASSERT(params.size() == 4); | |||
| std::vector<float *> data_ptrs; | |||
| std::vector<std::vector<int>> data_shapes; | |||
| for (auto ¶m : params) { | |||
| if (!utils::isa<ParameterPtr>(param)) { | |||
| MS_LOG(DEBUG) << "param is not Parameter node"; | |||
| return RET_FAILED; | |||
| } | |||
| auto param_t = utils::cast<ParameterPtr>(param); | |||
| if (!param_t->has_default()) { | |||
| MS_LOG(DEBUG) << "param not have default value"; | |||
| return RET_FAILED; | |||
| } | |||
| if (!utils::isa<ParamValueLitePtr>(param_t->default_param())) { | |||
| MS_LOG(DEBUG) << "default value is not ParamValueLite"; | |||
| return RET_FAILED; | |||
| } | |||
| auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(param_t->default_param()); | |||
| if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { | |||
| MS_LOG(DEBUG) << "origin_tensor is not float32 type"; | |||
| return RET_FAILED; | |||
| } | |||
| auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr()); | |||
| auto data_shape = origin_tensor->tensor_shape(); | |||
| data_ptrs.push_back(data_ptr); | |||
| data_shapes.push_back(data_shape); | |||
| } | |||
| for (size_t i = 1; i < data_shapes.size(); ++i) { | |||
| if (data_shapes[i] != data_shapes[0]) { | |||
| MS_LOG(DEBUG) << "data shape not same"; | |||
| return RET_FAILED; | |||
| } | |||
| } | |||
| auto new_default = std::make_shared<ParamValueLite>(); | |||
| if (new_default == nullptr) { | |||
| MS_LOG(ERROR) << "new_default is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> new_shape; | |||
| float *tensor_data = nullptr; | |||
| int step = 0; | |||
| int data_size = 0; | |||
| if (is_bias) { | |||
| if (data_shapes[0].size() != 1) { | |||
| MS_LOG(ERROR) << "bias data shape error"; | |||
| return RET_ERROR; | |||
| } | |||
| step = data_shapes[0][0]; | |||
| data_size = 8 * step; | |||
| new_shape = std::vector<int>({1, data_size}); | |||
| } else { | |||
| if (data_shapes[0].size() != 2) { | |||
| MS_LOG(ERROR) << "weight data shape error"; | |||
| return RET_ERROR; | |||
| } | |||
| new_shape = std::vector<int>({1, data_shapes[0][0] * 4, data_shapes[0][1]}); | |||
| step = data_shapes[0][0] * data_shapes[0][1]; | |||
| data_size = 4 * step; | |||
| } | |||
| tensor_data = new (std::nothrow) float[data_size]; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "new data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| for (int i = 0; i < data_size; ++i) { // bias are stored into first 4*hidden_size buffer, the rest is all 0 | |||
| tensor_data[i] = 0.0f; | |||
| } | |||
| for (size_t i = 0; i < data_ptrs.size(); ++i) { | |||
| auto source_len = std::accumulate(data_shapes[i].begin(), data_shapes[i].end(), 1, std::multiplies<int>()); | |||
| auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s error"; | |||
| delete[] tensor_data; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| new_default->set_tensor_shape(new_shape); | |||
| new_default->set_tensor_type(kNumberTypeFloat32); | |||
| new_default->set_format(schema::Format_NHWC); | |||
| new_default->SetTensorData(tensor_data, data_size * sizeof(float)); | |||
| new_param->set_default_param(new_default); | |||
| std::vector<int64_t> shape_vector(new_shape.begin(), new_shape.end()); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||
| if (abstract_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| new_param->set_abstract(abstract_tensor); | |||
| return RET_OK; | |||
| } | |||
| CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||
| const EquivPtr &body_equiv, const std::string &base_name, | |||
| const float smooth) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(equiv != nullptr); | |||
| MS_ASSERT(body_equiv != nullptr); | |||
| /* | |||
| * input vars for while node | |||
| * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_ | |||
| * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_ | |||
| */ | |||
| auto lstm_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>(); | |||
| attr->bidirection = false; | |||
| attr->smooth = smooth; | |||
| lstm_primitive->value.type = schema::PrimitiveType_Lstm; | |||
| lstm_primitive->value.value = attr.release(); | |||
| auto lstm_cvalue = lite::PrimitiveC::Create(lstm_primitive.release()); | |||
| auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(lstm_cvalue)); | |||
| auto &vars = while_input_vars_; | |||
| auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]); | |||
| MS_ASSERT(limit1); | |||
| auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]); | |||
| MS_ASSERT(limit2); | |||
| auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]); | |||
| MS_ASSERT(i2i_weight); | |||
| auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]); | |||
| MS_ASSERT(i2f_weight); | |||
| auto i2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[11]]); | |||
| MS_ASSERT(i2c_weight); | |||
| auto i2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[12]]); | |||
| MS_ASSERT(i2o_weight); | |||
| auto c2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[13]]); | |||
| MS_ASSERT(c2i_weight); | |||
| auto c2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[14]]); | |||
| MS_ASSERT(c2f_weight); | |||
| auto c2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[15]]); | |||
| MS_ASSERT(c2c_weight); | |||
| auto c2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[16]]); | |||
| MS_ASSERT(c2o_weight); | |||
| auto i_bias = utils::cast<AnfNodePtr>((*equiv)[vars[17]]); | |||
| MS_ASSERT(i_bias); | |||
| auto f_bias = utils::cast<AnfNodePtr>((*equiv)[vars[18]]); | |||
| MS_ASSERT(f_bias); | |||
| auto c_bias = utils::cast<AnfNodePtr>((*equiv)[vars[19]]); | |||
| MS_ASSERT(c_bias); | |||
| auto o_bias = utils::cast<AnfNodePtr>((*equiv)[vars[20]]); | |||
| MS_ASSERT(o_bias); | |||
| auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]); | |||
| MS_ASSERT(input); | |||
| auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]); | |||
| MS_ASSERT(cell); | |||
| auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]); | |||
| MS_ASSERT(hidden); | |||
| std::vector<AnfNodePtr> i_weights{i2i_weight, i2o_weight, i2f_weight, i2c_weight}; | |||
| auto i_weight = func_graph->add_parameter(); | |||
| auto status = GetConcatedParam(i_weights, i_weight, false); | |||
| if (status != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| i_weight->set_name(base_name + "_weight_i"); | |||
| std::vector<AnfNodePtr> c_weights{c2i_weight, c2o_weight, c2f_weight, c2c_weight}; | |||
| auto c_weight = func_graph->add_parameter(); | |||
| status = GetConcatedParam(c_weights, c_weight, false); | |||
| if (status != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| c_weight->set_name(base_name + "_weight_c"); | |||
| std::vector<AnfNodePtr> biases{i_bias, o_bias, f_bias, c_bias}; | |||
| auto bias = func_graph->add_parameter(); | |||
| status = GetConcatedParam(biases, bias, true); | |||
| if (status != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| bias->set_name(base_name + "_bias"); | |||
| if (!utils::isa<CNodePtr>(input) || GetCNodeType(input) != schema::PrimitiveType_TensorListFromTensor) { | |||
| MS_LOG(DEBUG) << "input is not tensorlistfromtensor op"; | |||
| return nullptr; | |||
| } | |||
| auto tensor_list_cnode = utils::cast<CNodePtr>(input); | |||
| auto input_tensor_node = tensor_list_cnode->input(1); | |||
| std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias, hidden, cell}; | |||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||
| new_node->set_fullname_with_scope(base_name); | |||
| return new_node; | |||
| } | |||
| CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| const int item_index) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(get_items != nullptr); | |||
| auto tuple_get_item_prim_ptr = lite::GetTupleGetItemPrim(); | |||
| if (tuple_get_item_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | |||
| auto get_item_value = NewValueNode(MakeValue<int>(item_index)); | |||
| if (tuple_get_item_prim == nullptr || get_item_value == nullptr) { | |||
| MS_LOG(ERROR) << "NewValueNode is nullptr"; | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, node, get_item_value}; | |||
| CNodePtr get_item_cnode = func_graph->NewCNode(inputs); | |||
| std::vector<int64_t> shape_vector; | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||
| if (abstract_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "create abstract_tensor failed"; | |||
| return nullptr; | |||
| } | |||
| get_item_cnode->set_abstract(abstract_tensor); | |||
| get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" + | |||
| std::to_string(item_index)); | |||
| return get_item_cnode; | |||
| } | |||
| STATUS TfliteLstmCellFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode, | |||
| const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(while_cnode != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto tr = manager->Transact(); | |||
| auto while_node_users = manager->node_users()[while_cnode]; | |||
| for (auto &node_user : while_node_users) { | |||
| if (node_user.first == output_get_item) { | |||
| continue; | |||
| } | |||
| if (!utils::isa<CNodePtr>(node_user.first)) { | |||
| return RET_ERROR; | |||
| } | |||
| auto get_item = utils::cast<CNodePtr>(node_user.first); | |||
| if (GetCNodeType(get_item) != schema::PrimitiveType_TupleGetItem) { | |||
| return RET_ERROR; | |||
| } | |||
| auto new_inputs = get_item->inputs(); | |||
| if (new_inputs.size() != 3) { | |||
| return RET_ERROR; | |||
| } | |||
| new_inputs[1] = lstm_cnode; | |||
| auto index_vnode = get_item->input(2); | |||
| if (!utils::isa<ValueNode>(index_vnode)) { | |||
| MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node"; | |||
| return RET_ERROR; | |||
| } | |||
| auto value_node = utils::cast<ValueNodePtr>(index_vnode); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(ERROR) << "cast to ValueNode failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto origin_index = GetValue<int>(value_node->value()); | |||
| int new_index = origin_index == 4 ? 2 : 1; | |||
| auto new_index_vnode = NewValueNode(MakeValue<int>(new_index)); | |||
| new_inputs[2] = new_index_vnode; | |||
| get_item->set_inputs(new_inputs); | |||
| get_item->set_fullname_with_scope(lstm_cnode->fullname_with_scope() + "_getitem_" + std::to_string(new_index)); | |||
| if (get_item->abstract() == nullptr) { | |||
| MS_LOG(ERROR) << "get_item's abstract is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> squeeze_axis{0}; | |||
| auto squeeze_node = CreateSqueezeNode(func_graph, get_item, squeeze_axis); | |||
| if (squeeze_node == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| auto get_item_users = manager->node_users()[get_item]; | |||
| for (auto &get_item_user : get_item_users) { | |||
| tr.SetEdge(get_item_user.first, get_item_user.second, squeeze_node); | |||
| } | |||
| } | |||
| tr.Commit(); | |||
| return RET_OK; | |||
| } | |||
| STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| AbstractBasePtrList abstract_list; | |||
| for (int i = 0; i < output_num; ++i) { | |||
| std::vector<int64_t> shape_vector; | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||
| if (abstract_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "create abstract_tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| abstract_list.emplace_back(abstract_tensor); | |||
| } | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| if (abstract_tuple == nullptr) { | |||
| MS_LOG(ERROR) << "create abstract_tuple failed"; | |||
| return RET_ERROR; | |||
| } | |||
| cnode->set_abstract(abstract_tuple); | |||
| return RET_OK; | |||
| } | |||
| CNodePtr TfliteLstmCellFusion::CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node, | |||
| const std::vector<int> &axis) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new SqueezeT failed"; | |||
| return nullptr; | |||
| } | |||
| attr->axis = axis; | |||
| auto new_primitive_t = std::make_unique<schema::PrimitiveT>(); | |||
| if (new_primitive_t == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_t is nullptr"; | |||
| return nullptr; | |||
| } | |||
| new_primitive_t->value.type = schema::PrimitiveType_Squeeze; | |||
| new_primitive_t->value.value = attr.release(); | |||
| auto new_primtive_c = std::shared_ptr<lite::PrimitiveC>(lite::PrimitiveC::Create(new_primitive_t.release())); | |||
| if (new_primtive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return nullptr; | |||
| } | |||
| ValueNodePtr value_node = NewValueNode(new_primtive_c); | |||
| auto squeeze_cnode = func_graph->NewCNode({value_node, input_node}); | |||
| squeeze_cnode->set_abstract(input_node->abstract()->Clone()); | |||
| squeeze_cnode->set_fullname_with_scope("squeeze_" + input_node->fullname_with_scope()); | |||
| return squeeze_cnode; | |||
| } | |||
| const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_LOG(DEBUG) << "lstm fusion pass"; | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| return nullptr; | |||
| } | |||
| auto tensor_list_stack_cnode = utils::cast<CNodePtr>(node); | |||
| auto tuple_get_item_node = tensor_list_stack_cnode->input(1); | |||
| if (!utils::isa<CNodePtr>(tuple_get_item_node)) { | |||
| return nullptr; | |||
| } | |||
| auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item_node); | |||
| auto while_node = tuple_get_item_cnode->input(1); | |||
| if (!utils::isa<CNodePtr>(while_node)) { | |||
| return nullptr; | |||
| } | |||
| auto while_cnode = utils::cast<CNodePtr>(while_node); | |||
| if (CheckIfCNodeIsNull(while_cnode) != RET_OK || CheckInputSize(while_cnode, while_inputs_num_) != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| if (!CheckReferencedOutputs(func_graph, while_cnode)) { | |||
| return nullptr; | |||
| } | |||
| PrimitiveVarMapPtr primitive_vars_cond = std::make_shared<PrimitiveVarMap>(); | |||
| auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond); | |||
| auto cond_equiv = CheckSubGraph(func_graph, cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), | |||
| cond_cnodes_num_, cond_nodes_num_); | |||
| if (cond_equiv == nullptr || cond_equiv->empty()) { | |||
| return nullptr; | |||
| } | |||
| PrimitiveVarMapPtr primitive_vars_body = std::make_shared<PrimitiveVarMap>(); | |||
| auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body); | |||
| auto body_equiv = CheckSubGraph(func_graph, body_graph_pattern, primitive_vars_body, while_cnode->input(2), | |||
| body_cnodes_num_, body_nodes_num_); | |||
| if (body_equiv == nullptr || body_equiv->empty()) { | |||
| return nullptr; | |||
| } | |||
| float smooth = 0.0f; | |||
| if (!CheckBodyGraph(func_graph, body_equiv, while_cnode, &smooth)) { | |||
| return nullptr; | |||
| } | |||
| const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope(); | |||
| auto lstm_node = CreateLSTMNode(func_graph, equiv, body_equiv, lstm_name, smooth); | |||
| if (lstm_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto status = SetAbstractTuple(lstm_node, kLSTMOutputNum); | |||
| if (status != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| auto get_item_node = CreateOutputGetItem(func_graph, lstm_node, 0); | |||
| if (get_item_node == nullptr) { | |||
| MS_LOG(DEBUG) << "create lstm output get_item node failed"; | |||
| return nullptr; | |||
| } | |||
| status = AdjustOtherGetItems(func_graph, while_cnode, lstm_node, tuple_get_item_cnode); | |||
| if (status != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| std::vector<int> squeeze_axis{1}; // our lstm output:0 have an extra axis that tflite not have, it must be squeezed | |||
| auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis); | |||
| if (squeeze_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1); | |||
| func_graph->DropFuncGraphCNodeIndex(cond_cnode_index_pair); | |||
| auto body_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 2); | |||
| func_graph->DropFuncGraphCNodeIndex(body_cnode_index_pair); | |||
| MS_LOG(INFO) << "lstm node:" << lstm_node->fullname_with_scope() << " fusion success"; | |||
| return squeeze_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TfliteLstmCellFusion : public PatternProcessPass { | |||
| public: | |||
| explicit TfliteLstmCellFusion(const std::string &name = "tflite_lstm_cell_fusion", bool multigraph = true, | |||
| int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0, | |||
| int body_nodes_num = 0, int body_cnodes_num = 0); | |||
| ~TfliteLstmCellFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| public: | |||
| static EquivPtr MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars, | |||
| const AnfNodePtr &pattern); | |||
| static EquivPtr CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern, | |||
| const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph, | |||
| const size_t cnode_num, const size_t all_node_num); | |||
| static lite::STATUS SetAbstractTuple(const CNodePtr &cnode, const int output_num); | |||
| static CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index); | |||
| protected: | |||
| VarPtr cell_smooth_old_ = nullptr; | |||
| VarPtr cell_smooth_new_ = nullptr; | |||
| VarPtr hidden_smooth_old_ = nullptr; | |||
| VarPtr hidden_smooth_new_ = nullptr; | |||
| std::vector<VarPtr> while_input_vars_; | |||
| lite::STATUS GetFloatScalarFromParamValueLite(const AnfNodePtr ¶m_value, float *v) const; | |||
| CNodePtr CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node, | |||
| const std::vector<int> &axis) const; | |||
| lite::STATUS AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode, | |||
| const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) const; | |||
| AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; | |||
| virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; | |||
| virtual CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv, | |||
| const std::string &base_name, const float smooth) const; | |||
| private: | |||
| bool CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const CNodePtr &while_cnode, | |||
| float *smooth) const; | |||
| bool CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) const; | |||
| lite::STATUS GetConcatedParam(const std::vector<AnfNodePtr> ¶ms, const ParameterPtr &new_param, | |||
| bool is_bias) const; | |||
| private: | |||
| size_t while_input_var_num_ = 0; | |||
| size_t while_inputs_num_ = 0; | |||
| size_t cond_nodes_num_ = 0; | |||
| size_t cond_cnodes_num_ = 0; | |||
| size_t body_nodes_num_ = 0; | |||
| size_t body_cnodes_num_ = 0; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_ | |||