From: @wangzhe128 (tags/v1.2.0-rc1)
@@ -0,0 +1,134 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| #include "nnacl/fp32/gru_fp32.h" | |||||
| #include <string.h> | |||||
| #include "nnacl/fp32/lstm_fp32.h" | |||||
| #include "nnacl/fp32/activation_fp32.h" | |||||
| #include "nnacl/fp32/arithmetic_fp32.h" | |||||
void InitGruGate(float *gate_buffer, const float *bias, const GruParameter *gru_parm) {
  int gate_offset = 0;
  for (int l = 0; l < 3; l++) {
    int batch_offset = gate_offset;
    int bias_offset = l * gru_parm->hidden_size_;
    for (int b = 0; b < gru_parm->batch_; b++) {
      memcpy(gate_buffer + batch_offset, bias + bias_offset, gru_parm->hidden_size_ * sizeof(float));
      batch_offset += gru_parm->hidden_size_;
    }
    gate_offset += gru_parm->batch_ * gru_parm->hidden_size_;
  }
}

void GruStepUnit(float *output, const float *input, const float *input_reset_weight, const float *input_update_weight,
                 const float *input_hidden_weight, const float *state_reset_weight, const float *state_update_weight,
                 const float *state_hidden_weight, const float *bias, float *hidden_state, float *gate_buffer,
                 const GruParameter *gru_parm) {
  InitGruGate(gate_buffer, bias, gru_parm);

  float *update_gate = gate_buffer;
  float *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
  float *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;

  // input * weight
  MatMulAcc(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
  MatMulAcc(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
  MatMulAcc(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);

  // state * weight
  MatMulAcc(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);
  MatMulAcc(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);

  // update reset_gate
  Sigmoid(reset_gate, gru_parm->batch_ * gru_parm->hidden_size_, reset_gate);
  // update update_gate
  Sigmoid(update_gate, gru_parm->batch_ * gru_parm->hidden_size_, update_gate);

  ElementMul(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
  MatMulAcc(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);
  Tanh(hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_, hidden_buffer);

  ElementMul(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);

  ArithmeticParameter parameter;
  parameter.in_elements_num0_ = 1;
  parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
  const float one = 1.0f;
  ElementOptSub(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter);

  ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
  memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float));
}
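
/* Editor's note: for reference, GruStepUnit above is the standard GRU cell with
 * linear_before_reset = 0, gates ordered update (z), reset (r), candidate (n):
 *   z_t = sigmoid(W_z * x_t + R_z * h_{t-1} + b_z)
 *   r_t = sigmoid(W_r * x_t + R_r * h_{t-1} + b_r)
 *   n_t = tanh(W_n * x_t + R_n * (r_t . h_{t-1}) + b_n)    ("." = element-wise)
 *   h_t = z_t . h_{t-1} + (1 - z_t) . n_t
 * where each b_* is the pre-summed input + recurrent bias prepared by the kernel. */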

void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
         float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm) {
  // forward
  const float *input_update_weight = weight_g;
  const float *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
  const float *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;

  const float *state_update_weight = weight_r;
  const float *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
  const float *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;

  for (int t = 0; t < check_seq_len; t++) {
    const float *input_ptr = input + t * gru_parm->input_step_;
    float *output_ptr = output + t * gru_parm->output_step_;
    GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, state_reset_weight,
                state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer, gru_parm);
  }
  // zero out extra fw outputs
  for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
    float *output_ptr = output + t * gru_parm->output_step_;
    for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
      output_ptr[i] = 0.0f;
    }
  }

  // backward
  if (gru_parm->bidirectional_) {
    input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
    input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
    input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;

    state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
    state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
    state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;

    float *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
    const float *backward_bias = bias + 3 * gru_parm->hidden_size_;
    float *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
    for (int t = check_seq_len - 1; t >= 0; t--) {
      const float *input_ptr = input + t * gru_parm->input_step_;
      float *output_ptr = backward_output + t * gru_parm->output_step_;
      GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
                  state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, backward_hidden_state,
                  gate_buffer, gru_parm);
    }
    // zero out extra bw outputs
    for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
      float *output_ptr = backward_output + t * gru_parm->output_step_;
      for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
        output_ptr[i] = 0.0f;
      }
    }
  }
}
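
Editor's note on the layout above: in the bidirectional case, weight_g and weight_r each hold six gate slices (of size hidden_size x input_size and hidden_size x hidden_size respectively), ordered update/reset/candidate for the forward direction (slices 0 to 2) followed by the backward direction (slices 3 to 5). Forward and backward results share each time step's output block, with the backward half offset by batch_ * hidden_size_ inside output_step_.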

@@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
#define MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_

#include "nnacl/op_base.h"

typedef struct GruParameter {
  // Primitive parameter
  OpParameter op_parameter_;
  // shape correlative
  int input_size_;
  int hidden_size_;  // output_size
  int seq_len_;
  int batch_;
  // other parameter
  int input_step_;
  int output_step_;
  bool bidirectional_;
} GruParameter;

#ifdef __cplusplus
extern "C" {
#endif
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
         float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
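
To make the calling convention concrete, here is a minimal sketch of how one might drive this API for a single unidirectional sequence (buffer sizes follow InitGruGate/GruStepUnit above; RunGruOnce and the zeroed initial state are illustrative, not part of the patch):

#include <stdlib.h>
#include "nnacl/fp32/gru_fp32.h"

/* Hypothetical helper: output must hold seq_len_ * output_step_ floats. */
int RunGruOnce(float *output, const float *input, const float *weight_g, const float *weight_r,
               const float *bias, const GruParameter *p) {
  int state_size = p->batch_ * p->hidden_size_;
  float *hidden_state = calloc(state_size, sizeof(float));      /* h_0 = 0 */
  float *gate_buffer = malloc(3 * state_size * sizeof(float));  /* z, r, n blocks */
  if (hidden_state == NULL || gate_buffer == NULL) {
    free(hidden_state);
    free(gate_buffer);
    return -1;
  }
  Gru(output, input, weight_g, weight_r, bias, hidden_state, gate_buffer, p->seq_len_, p);
  free(gate_buffer);
  free(hidden_state);
  return 0;
}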

@@ -16,6 +16,7 @@
 #include "nnacl/fp32/lstm_fp32.h"
 #include <string.h>
+#include <float.h>
 #include "nnacl/fp32/activation_fp32.h"
 #include "nnacl/fp32/arithmetic_fp32.h"

@@ -79,21 +80,63 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int
   }
 }

+int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  for (; index <= element_size - C4NUM; index += C4NUM) {
+    float32x4_t vin0 = vld1q_f32(input0 + index);
+    float32x4_t vout = vld1q_f32(output + index);
+    vout = vmlaq_n_f32(vout, vin0, input1);
+    vst1q_f32(output + index, vout);
+  }
+#endif
+  for (; index < element_size; index++) {
+    output[index] += input0[index] * input1;
+  }
+  return NNACL_OK;
+}
+
 void UpdataState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate,
-                 int batch, int hidden_size) {
+                 float *state_buffer, int batch, int hidden_size, const float smooth) {
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {  // smooth * old_cell_state
+    memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float));
+    ArithmeticParameter parameter;
+    parameter.in_elements_num0_ = batch * hidden_size;
+    parameter.in_elements_num1_ = 1;
+    ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
+  }
   ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size);
   ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size);
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {  // (1 - smooth) * new_cell_state
+    ElementOptMulAcc(cell_state, 1 - smooth, state_buffer, batch * hidden_size);
+  }
 }

-void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, int batch, int hidden_size) {
+void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, float *state_buffer_in,
+                  int batch, int hidden_size, const float smooth) {
+  float *state_buffer = state_buffer_in + batch * hidden_size;
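+  // state_buffer_in packs two [batch, hidden_size] planes: plane 0 carries the
+  // smoothed cell state (written in UpdataState), plane 1 the smoothed hidden state.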
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
+    memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float));
+    ArithmeticParameter parameter;
+    parameter.in_elements_num0_ = batch * hidden_size;
+    parameter.in_elements_num1_ = 1;
+    ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
+  }
   Tanh(cell_state, batch * hidden_size, hidden_state);
   ElementMul(hidden_state, output_gate, hidden_state, batch * hidden_size);
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
+    ElementOptMulAcc(hidden_state, 1 - smooth, state_buffer, batch * hidden_size);
+  }
 }

 void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight,
                   const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight,
                   const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight,
-                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer,
+                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
                   const LstmParameter *lstm_parm) {
   InitGate(gate_buffer, bias, lstm_parm);

@@ -129,17 +172,26 @@ void LstmStepUnit(float *output, const float *input, const float *input_input_we
   // update cell_gate
   Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate);
   // update cell state
-  UpdataState(cell_state, forget_gate, input_gate, cell_gate, lstm_parm->batch_, lstm_parm->hidden_size_);
+  UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
+              lstm_parm->smooth_);
   // update output_gate
   Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate);
   // update output
-  UpdataOutput(cell_state, output_gate, hidden_state, lstm_parm->batch_, lstm_parm->hidden_size_);
+  UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
+               lstm_parm->smooth_);
   memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
+  if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) {
+    memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
+    memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_,
+           lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
+  }
 }

 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
-          float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm) {
+          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
+          const LstmParameter *lstm_parm) {
   // forward
   const float *input_input_weight = weight_i;
   const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;

@@ -156,7 +208,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float
     float *output_ptr = output + t * lstm_parm->output_step_;
     LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight,
                  state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state,
-                 cell_state, gate_buffer, lstm_parm);
+                 cell_state, gate_buffer, state_buffer, lstm_parm);
   }

   // backward

@@ -180,7 +232,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float
       float *output_ptr = backward_output + t * lstm_parm->output_step_;
       LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
                    input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight,
-                   backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, lstm_parm);
+                   backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm);
     }
   }
 }

@@ -31,13 +31,24 @@ typedef struct LstmParameter {
   int input_step_;
   int output_step_;
   bool bidirectional_;
+  // smooth factor for hidden/cell state calculation:
+  //   output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
+  //   output_cell = old_cell * smooth + new_cell * (1 - smooth)
+  float smooth_;
 } LstmParameter;

 #ifdef __cplusplus
 extern "C" {
 #endif
+void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size);
+void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size);
+int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);
 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
-          float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm);
+          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
+          const LstmParameter *lstm_parm);
 #ifdef __cplusplus
 }
 #endif
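
Reviewer note: the new smooth_ field turns the recurrent state update into an exponential moving average over time steps. A scalar sketch of the per-element blend that UpdataState/UpdataOutput perform when smooth is non-zero (illustrative, not part of the patch):

float blend(float old_state, float new_state, float smooth) {
  return old_state * smooth + new_state * (1.0f - smooth);
}

With the schema default smooth = 0.0, the FLT_EPSILON guards skip the extra memcpy/ElementOptMul work entirely, so existing LSTM models pay no cost.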

@@ -262,6 +262,7 @@ union PrimitiveType {
   Merge,
   Mod,
   GeLU,
+  Gru,
 }

 enum QuantType: int {

@@ -1005,6 +1005,11 @@ table OneHot {
 table Lstm{
   bidirection: bool = false;
+  smooth: float = 0.0;
+}
+
+table Gru{
+  bidirection: bool = false;
 }

 table PriorBox {

@@ -0,0 +1,121 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/ops/gru.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool Gru::GetBidirection() const { return this->primitive_->value.AsGru()->bidirection; }

void Gru::SetBidirection(bool bidirection) { this->primitive_->value.AsGru()->bidirection = bidirection; }

#else
bool Gru::GetBidirection() const { return this->primitive_->value_as_Gru()->bidirection(); }

int Gru::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_Gru();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_Gru return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateGru(*fbb, attr->bidirection());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gru, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}

PrimitiveC *GruCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Gru>(primitive); }
Registry GruRegistry(schema::PrimitiveType_Gru, GruCreator);
#endif

const int kGruInputNum = 5;
const int kGruInputWithSeqLenNum = 6;
const int kGruOutputNum = 2;

int Gru::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(this->primitive_ != nullptr);
  if ((inputs_.size() != kGruInputNum && inputs_.size() != kGruInputWithSeqLenNum) ||
      outputs_.size() != kGruOutputNum) {
    MS_LOG(ERROR) << "OpGru inputs or outputs size error.";
    return RET_INPUT_TENSOR_ERROR;
  }
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto weight_gate = inputs_.at(1);
  MS_ASSERT(weight_gate != nullptr);
  auto weight_recurrence = inputs_.at(2);
  MS_ASSERT(weight_recurrence != nullptr);
  auto bias = inputs_.at(3);
  MS_ASSERT(bias != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  for (int i = 0; i < kGruOutputNum; i++) {
    outputs_.at(i)->set_data_type(input->data_type());
    outputs_.at(i)->set_format(input->format());
  }
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  auto in_shape = input->shape();                  // seq_len, batch, input_size
  auto w_gate_shape = weight_gate->shape();        // num_direction, hidden_size * 3, input_size
  auto w_recu_shape = weight_recurrence->shape();  // num_direction, hidden_size * 3, hidden_size
  auto bias_shape = bias->shape();                 // num_direction, hidden_size * 6
  if (in_shape.size() != 3 || w_gate_shape.size() != 3 || w_recu_shape.size() != 3) {
    MS_LOG(ERROR) << "OpGru input dims should be 3.";
    return RET_ERROR;
  }
  if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) {
    MS_LOG(ERROR) << "OpGru w_gate, w_recu and bias hidden size not match.";
    return RET_ERROR;
  }
  if (inputs_.size() == kGruInputWithSeqLenNum) {
    auto seq_len_shape = inputs_.at(5)->shape();
    if (seq_len_shape.size() != 1 || seq_len_shape[0] != in_shape[1]) {
      MS_LOG(ERROR) << "OpGru sequence_len shape[0] and batch_size not match.";
      return RET_ERROR;
    }
    if (seq_len_shape[0] > 1) {
      MS_LOG(WARNING) << "OpGru with batch_size > 1 only supports all-equal sequence_len for now.";
    }
  }
  int hidden_size = w_gate_shape[1] / 3;
  // set output
  std::vector<int> out_shape(in_shape);
  out_shape[2] = hidden_size;
  if (GetBidirection()) {
    out_shape.insert(out_shape.begin() + 1, 2);
  } else {
    out_shape.insert(out_shape.begin() + 1, 1);
  }
  output->set_shape(out_shape);
  // set hidden state
  std::vector<int> state_shape(in_shape);
  state_shape[0] = GetBidirection() ? 2 : 1;
  state_shape[2] = hidden_size;
  outputs_[1]->set_shape(state_shape);

  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
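
Worked example of the InferShape logic above (illustrative numbers): an input of [seq_len = 8, batch = 2, input_size = 16] with weight_gate [2, 96, 16] gives hidden_size = 96 / 3 = 32; with bidirection = true the first output becomes [8, 2, 2, 32] (seq_len, num_directions, batch, hidden_size) and the hidden-state output becomes [2, 2, 32].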

@@ -0,0 +1,47 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_OPS_GRU_H_
#define MINDSPORE_LITE_SRC_OPS_GRU_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
/*
 * gru with linear_before_reset = 0
 */
class Gru : public PrimitiveC {
 public:
  Gru() = default;
  ~Gru() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Gru, PrimitiveC);
  explicit Gru(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetBidirection(bool bidirection);
#else
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  bool GetBidirection() const;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_OPS_GRU_H_

@@ -25,11 +25,16 @@ namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
 bool Lstm::GetBidirection() const { return this->primitive_->value.AsLstm()->bidirection; }
+float Lstm::GetSmooth() const { return this->primitive_->value.AsLstm()->smooth; }

 void Lstm::SetBidirection(bool bidirection) { this->primitive_->value.AsLstm()->bidirection = bidirection; }
+void Lstm::SetSmooth(float smooth) { this->primitive_->value.AsLstm()->smooth = smooth; }

 #else
 bool Lstm::GetBidirection() const { return this->primitive_->value_as_Lstm()->bidirection(); }
+float Lstm::GetSmooth() const { return this->primitive_->value_as_Lstm()->smooth(); }

 int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
   MS_ASSERT(nullptr != fbb);

@@ -38,7 +43,7 @@ int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
     MS_LOG(ERROR) << "value_as_Lstm return nullptr";
     return RET_ERROR;
   }
-  auto val_offset = schema::CreateLstm(*fbb, attr->bidirection());
+  auto val_offset = schema::CreateLstm(*fbb, attr->bidirection(), attr->smooth());
   auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Lstm, val_offset.o);
   fbb->Finish(prim_offset);
   return RET_OK;

@@ -33,12 +33,14 @@ class Lstm : public PrimitiveC {
   MS_DECLARE_PARENT(Lstm, PrimitiveC);
   explicit Lstm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
   void SetBidirection(bool bidirection);
+  void SetSmooth(float smooth);
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
   int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
   bool GetBidirection() const;
+  float GetSmooth() const;
 };
 }  // namespace lite
 }  // namespace mindspore

@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstring>
#include "src/ops/gru.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/gru_fp32.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateGruParameter(const mindspore::lite::PrimitiveC *primitive) {
  GruParameter *gru_param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter)));
  if (gru_param == nullptr) {
    MS_LOG(ERROR) << "malloc GruParameter failed.";
    return nullptr;
  }
  memset(gru_param, 0, sizeof(GruParameter));
  gru_param->op_parameter_.type_ = primitive->Type();
  auto param = reinterpret_cast<mindspore::lite::Gru *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  if (param == nullptr) {
    free(gru_param);
    MS_LOG(ERROR) << "get Gru param nullptr.";
    return nullptr;
  }
  gru_param->bidirectional_ = param->GetBidirection();
  return reinterpret_cast<OpParameter *>(gru_param);
}
Registry GruParameterRegistry(schema::PrimitiveType_Gru, PopulateGruParameter);
}  // namespace lite
}  // namespace mindspore

@@ -36,6 +36,7 @@ OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive)
     return nullptr;
   }
   lstm_param->bidirectional_ = param->GetBidirection();
+  lstm_param->smooth_ = param->GetSmooth();
   return reinterpret_cast<OpParameter *>(lstm_param);
 }
 Registry LstmParameterRegistry(schema::PrimitiveType_Lstm, PopulateLstmParameter);

@@ -161,6 +161,7 @@
 #include "src/ops/switch.h"
 #include "src/ops/partial.h"
 #include "src/ops/gelu.h"
+#include "src/ops/gru.h"

 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"

@@ -995,6 +996,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) AssertOP(primitive);
     case schema::PrimitiveType_GeLU:
       return new (std::nothrow) GeLU(primitive);
+    case schema::PrimitiveType_Gru:
+      return new (std::nothrow) Gru(primitive);
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
       return new (std::nothrow) ActivationGrad(primitive);

@@ -0,0 +1,165 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp32/gru_fp32.h"
#include <algorithm>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gru;

namespace mindspore::kernel {
void GruCPUKernel::FreeTmpBuffer() {
  if (gate_buffer_ != nullptr) {
    free(gate_buffer_);
    gate_buffer_ = nullptr;
  }
  if (bias_ptr_ != nullptr) {
    free(bias_ptr_);
    bias_ptr_ = nullptr;
  }
  weight_g_ptr_ = nullptr;
  weight_r_ptr_ = nullptr;
}

int GruCPUKernel::InitParam() {
  auto input = in_tensors_.front();
  MS_ASSERT(input != nullptr);
  std::vector<int> in_shape = input->shape();
  gru_parm_->seq_len_ = in_shape.at(0);
  gru_parm_->batch_ = in_shape.at(1);
  gru_parm_->input_size_ = in_shape.at(2);

  auto weight_g = in_tensors_.at(1);
  MS_ASSERT(weight_g != nullptr);
  std::vector<int> w_shape = weight_g->shape();
  gru_parm_->hidden_size_ = w_shape.at(1) / 3;

  gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
  gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
                                                      : gru_parm_->batch_ * gru_parm_->hidden_size_;
  return RET_OK;
}

int GruCPUKernel::InitBuffer() {
  gate_buffer_ = reinterpret_cast<float *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float)));
  if (gate_buffer_ == nullptr) {
    MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
    return RET_ERROR;
  }
  return RET_OK;
}
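
/* Editor's note: with linear_before_reset = 0 every gate's input bias and
 * recurrent bias only ever appear as a sum, so InitWeightBias below folds
 * bias[i] + bias[i + 3 * hidden_size] once per direction instead of adding
 * both halves at every time step. */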
int GruCPUKernel::InitWeightBias() {
  auto weight_gate = in_tensors_.at(1);
  MS_ASSERT(weight_gate != nullptr);
  weight_g_ptr_ = reinterpret_cast<float *>(weight_gate->data_c());
  auto weight_recu = in_tensors_.at(2);
  MS_ASSERT(weight_recu != nullptr);
  weight_r_ptr_ = reinterpret_cast<float *>(weight_recu->data_c());

  int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
  bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
  if (bias_ptr_ == nullptr) {
    MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
    return RET_ERROR;
  }
  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
  const int state_bias_offset = 3 * gru_parm_->hidden_size_;
  for (int i = 0; i < state_bias_offset; i++) {
    bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
  }
  if (gru_parm_->bidirectional_) {
    bias_data += 3 * gru_parm_->hidden_size_ * 2;
    auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
    for (int i = 0; i < state_bias_offset; i++) {
      backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
    }
  }
  return RET_OK;
}

int GruCPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int GruCPUKernel::ReSize() {
  FreeTmpBuffer();
  auto ret = InitParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruCPUKernel InitParam error.";
    return RET_ERROR;
  }
  ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error.";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  ret = InitBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruCPUKernel InitBuffer error.";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  return RET_OK;
}

int GruCPUKernel::Run() {
  auto input = in_tensors_.at(kInputIndex);
  MS_ASSERT(input != nullptr);
  auto hidden_state = in_tensors_.at(4);
  MS_ASSERT(hidden_state != nullptr);
  auto output = out_tensors_.at(0);
  MS_ASSERT(output != nullptr);
  auto input_ptr = reinterpret_cast<float *>(input->data_c());
  MS_ASSERT(input_ptr);
  auto output_ptr = reinterpret_cast<float *>(output->MutableData());
  MS_ASSERT(output_ptr);
  auto output_hidden_state = out_tensors_[1];
  memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));

  int check_seq_len = gru_parm_->seq_len_;
  if (in_tensors_.size() == 6) {
    auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
    if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
      MS_LOG(ERROR) << "different batch seq_len is currently not supported";
      return RET_ERROR;
    }
    check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
  }

  MS_ASSERT(weight_g_ptr_);
  MS_ASSERT(weight_r_ptr_);
  MS_ASSERT(bias_ptr_);
  MS_ASSERT(gate_buffer_);
  Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
      reinterpret_cast<float *>(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_);
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gru, LiteKernelCreator<GruCPUKernel>)
}  // namespace mindspore::kernel

@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp32/gru_fp32.h"

namespace mindspore::kernel {
class GruCPUKernel : public LiteKernel {
 public:
  GruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
               const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
               const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
  }

  ~GruCPUKernel() override { FreeTmpBuffer(); }

  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  void FreeTmpBuffer();
  int InitParam();
  int InitBuffer();
  int InitWeightBias();

  float *gate_buffer_ = nullptr;
  const float *weight_g_ptr_ = nullptr;
  const float *weight_r_ptr_ = nullptr;
  float *bias_ptr_ = nullptr;
  GruParameter *gru_parm_ = nullptr;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_

@@ -15,6 +15,7 @@
  */

 #include "src/runtime/kernel/arm/fp32/lstm_fp32.h"
+#include <float.h>
 #include <vector>
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"

@@ -32,6 +33,10 @@ void LstmCPUKernel::FreeTmpBuffer() {
     free(gate_buffer_);
     gate_buffer_ = nullptr;
   }
+  if (state_buffer_ != nullptr) {
+    free(state_buffer_);
+    state_buffer_ = nullptr;
+  }
   if (weight_i_ptr_ != nullptr) {
     free(weight_i_ptr_);
     weight_i_ptr_ = nullptr;

@@ -71,6 +76,14 @@ int LstmCPUKernel::InitBuffer() {
     MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
     return RET_ERROR;
   }
+  if (!(lstm_parm_->smooth_ >= -FLT_EPSILON && lstm_parm_->smooth_ <= FLT_EPSILON)) {
+    int buffer_size = 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float);
+    state_buffer_ = reinterpret_cast<float *>(malloc(buffer_size));
+    if (state_buffer_ == nullptr) {
+      MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error.";
+      return RET_ERROR;
+    }
+  }
   return RET_OK;
 }

@@ -173,7 +186,7 @@ int LstmCPUKernel::Run() {
   MS_ASSERT(gate_buffer_);
   Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
        reinterpret_cast<float *>(output_hidden_state->MutableData()),
-       reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, lstm_parm_);
+       reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, state_buffer_, lstm_parm_);
   return RET_OK;
 }

@@ -44,6 +44,7 @@ class LstmCPUKernel : public LiteKernel {
   int InitWeightBias();

   float *gate_buffer_ = nullptr;
+  float *state_buffer_ = nullptr;
   float *weight_i_ptr_ = nullptr;
   float *weight_h_ptr_ = nullptr;
   float *bias_ptr_ = nullptr;

@@ -187,6 +187,9 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
         ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc

@@ -0,0 +1 @@
decoder_step_201217.pb 5
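
This new single-line model config (read below as ${models_tf_config}) follows the format `<model file> <input count>`; the TF loop added to Run_arm64 in the next hunk parses both fields with awk to build the --inDataFile list.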

@@ -925,6 +925,41 @@ function Run_arm64() {
         fi
     done < ${models_compatibility_config}

+    # Run tf converted models:
+    while read line; do
+        if [[ $line == \#* ]]; then
+            continue
+        fi
+        model_name=`echo ${line} | awk -F ' ' '{print $1}'`
+        input_num=`echo ${line} | awk -F ' ' '{print $2}'`
+        input_files=''
+        for i in $(seq 1 $input_num)
+        do
+            input_files=$input_files'/data/local/tmp/input_output/input/'$model_name'.ms_'$i'.bin,'
+        done
+        echo ${model_name} >> "${run_arm64_log_file}"
+        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> "${run_arm64_log_file}"
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> adb_run_cmd.txt
+        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
+        if [ $? = 0 ]; then
+            run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+        else
+            run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+        fi
+        # run benchmark test without clib data
+        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> "${run_arm64_log_file}"
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt
+        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
+        if [ $? = 0 ]; then
+            run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+        else
+            run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+        fi
+    done < ${models_tf_config}
+
     # Run tflite converted models:
     while read line; do
         model_name=${line}

@@ -46,6 +46,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/fusion/batchmatmul_fusion.cc
         ../optimizer/fusion/sigmoid_mul_fusion.cc
         ../optimizer/fusion/conv_conv_fusion.cc
+        ../optimizer/fusion/tflite_lstm_cell_fusion.cc
+        ../optimizer/fusion/tf_lstm_cell_fusion.cc
+        ../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
         ../optimizer/graph/weight_format_transform_pass.cc
         ../optimizer/graph/weight_format_hardcode_pass.cc
         ../optimizer/graph/clip_convert_activation_pass.cc

@@ -29,6 +29,9 @@
 #include "tools/optimizer/fusion/batchmatmul_fusion.h"
 #include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
 #include "tools/optimizer/fusion/conv_conv_fusion.h"
+#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
+#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
+#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
 #include "tools/optimizer/graph/mindir_adjust_pass.h"
 #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
 #include "tools/optimizer/graph/identity_remove_pass.h"

@@ -114,6 +117,9 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
     fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
     fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
     fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>());
   }
   auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
   weight_format_hardcode_pass->SetFmkType(config->fmk);

@@ -572,7 +572,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
       if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) {
         type = TypeIdToType(kObjectTypeTensorType);
       }
-      anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector));
+      auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape_vector);
+      if (abstract == nullptr) {
+        MS_LOG(ERROR) << "create AbstractTensor failed";
+        return RET_ERROR;
+      }
+      anf_node->set_abstract(abstract);
       anf_node_map->insert(std::pair(op.name(), anf_node));
     } else {
       AbstractBasePtrList abstractList;

@@ -589,6 +594,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
         std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue};
         CNodePtr getItemCNode = anf_graph->NewCNode(inputs);
         std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
+        auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
+        if (abstract == nullptr) {
+          MS_LOG(ERROR) << "create AbstractTensor failed";
+          return RET_ERROR;
+        }
+        getItemCNode->set_abstract(abstract);
         getItemCNode->set_fullname_with_scope(output_item_name);
         anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode));
       }

@@ -63,7 +63,11 @@ STATUS TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op,
   }

   *output_size = 1;
-  return AddOpInput(tf_op, 0, inputs);
+  auto status = AddOpInput(tf_op, 0, inputs);
+  if (status != RET_OK) {
+    return status;
+  }
+  return AddOpInput(tf_op, 1, inputs);
 }
 TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser());
 }  // namespace lite
@@ -0,0 +1,61 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_select_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFSelectParser::Parse(const tensorflow::NodeDef &tf_op,
                             const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                             std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF SelectParser";
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC, inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::SwitchT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }
  primitive->value.type = schema::PrimitiveType_Switch;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = 1;
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfSelectParser("Select", new TFSelectParser());
}  // namespace lite
}  // namespace mindspore
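TF Select is mapped onto the lite Switch primitive here rather than onto a dedicated Select op. For illustration, a hypothetical call site exercising the declared Parse signature (the converter's real driver loop is not part of this diff, and select_node_def / node_map below are placeholders):

  PrimitiveC *primitive_c = nullptr;
  std::vector<std::string> input_names;
  int output_size = 0;
  std::map<string, const tensorflow::NodeDef *> node_map;  // filled by the converter in practice
  auto status = TFSelectParser().Parse(select_node_def, node_map, &primitive_c, &input_names, &output_size);
  // on RET_OK: primitive_c wraps a Switch primitive and input_names holds every input of the node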
@@ -0,0 +1,37 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFSelectParser : public TFNodeParser {
 public:
  TFSelectParser() = default;
  ~TFSelectParser() override = default;
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SELECT_PARSER_H_
@@ -122,7 +122,7 @@ std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) {
                                                   std::sregex_token_iterator());
   std::string ret = input_name;
   if (input_splits.size() == 3) {
-    if (input_splits[2] == "0") {
+    if (input_splits[2].compare("0") == 0) {
       ret = input_splits[0];
     } else {
       ret = input_splits[0] + ":" + input_splits[2];  // multi output node
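For reference, the rule above collapses the default output port and keeps explicit ones. Assuming the regex yields three tokens of the form {node, port_word, index} (the split expression itself is outside this hunk), the mapping is:

  //   {"conv1", ..., "0"}  ->  "conv1"      (default output port dropped)
  //   {"split", ..., "2"}  ->  "split:2"    (multi output node keeps its index)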
@@ -0,0 +1,679 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
#include <memory>
#include <functional>
#include <numeric>  // std::accumulate, used in AddDefaultParameter below
#include "src/ops/primitive_c.h"
#include "src/common/utils.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kWhileCommonInputsLength = 2;
constexpr size_t kWhileUniqInputsLength = 6;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 69;
constexpr size_t kBodyCNodesNum = 25;
const auto &p1 = std::placeholders::_1;
bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    return opt::GetCNodeType(n) == type;
  }
  return false;
}
}  // namespace
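IsOpType is curried with std::bind throughout the pattern definitions below; as a quick reference, the recurring idiom is:

  // fix the op type, leave the node to be matched as p1
  auto is_concat = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Concat));
  // is_concat now matches any CNode or ValueNode whose primitive type is Concat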
BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, bool multigraph)
    : PatternProcessPass(name, multigraph) {
  /*
   * vars for while input
   * common:
   * 0:const0 1:init_state
   * fw_while_inputs:
   * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
   * bw_while_inputs:
   * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
   */
  for (size_t i = 0; i < kWhileCommonInputsLength; ++i) {
    common_vars_.emplace_back(std::make_shared<Var>());
  }
  for (size_t i = 0; i < kWhileUniqInputsLength; ++i) {
    fw_vars_.emplace_back(std::make_shared<Var>());
    bw_vars_.emplace_back(std::make_shared<Var>());
  }
  input_ = std::make_shared<Var>();
  input_length_ = std::make_shared<Var>();
  transpose_input_ = std::make_shared<Var>();
}
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
  auto const1 = std::make_shared<CondVar>(IsParameterNode);
  auto ele_shape = std::make_shared<CondVar>(IsParameterNode);
  // forward
  auto fw_max1 =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_});
  auto fw_max2 =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, fw_max1});
  auto fw_shape =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_});
  auto fw_stride =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), fw_shape});
  auto fw_min =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2});
  auto fw_reserve =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape,
               fw_stride});
  auto fw_from_tensor =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)),
               transpose_input_, ele_shape});
  auto is_fw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
  auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], common_vars_[0], fw_stride, common_vars_[0],
                             fw_reserve, common_vars_[1], fw_min, fw_from_tensor, input_length_});
  fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end());
  fw_while.emplace_back(common_vars_[1]);
  auto fw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)),
                                fw_while, std::make_shared<Var>()});
  auto fw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)),
                             fw_get_item, ele_shape});
  auto fw_out_trans =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), fw_stack});
  // backward
  auto bw_reverse_seq = VectorRef(
    {std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), input_, input_length_});
  auto bw_max1 =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_});
  auto bw_max2 =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, bw_max1});
  auto bw_trans =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_reverse_seq});
  auto bw_shape =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans});
  auto bw_stride =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), bw_shape});
  auto bw_min =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2});
  auto bw_reserve =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape,
               bw_stride});
  auto bw_from_tensor =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), bw_trans,
               ele_shape});
  auto is_bw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
  auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], common_vars_[0], bw_stride, common_vars_[0],
                             bw_reserve, common_vars_[1], bw_min, bw_from_tensor, input_length_});
  bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end());
  bw_while.emplace_back(common_vars_[1]);
  auto bw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)),
                                bw_while, std::make_shared<Var>()});
  auto bw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)),
                             bw_get_item, ele_shape});
  auto bw_out_trans =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_stack});
  auto bw_reverse1 =
    VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), bw_out_trans,
               input_length_});
  auto concat = VectorRef(
    {std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Concat)), fw_out_trans, bw_reverse1});
  return concat;
}
AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
  auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
  auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
  auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
  auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode);
  auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less));
  auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less));
  auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_LogicalAnd));
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return));
  VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
  VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
  VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
  VectorRef return_ref = VectorRef({is_return, logicaland_ref});
  VarPtr fg = std::make_shared<Var>("RootG");
  auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true);
  return pattern;
}
AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
  std::vector<CondVarPtr> placeholders;
  for (int i = 0; i < 13; ++i) {
    placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
  }
  VectorRef add = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
  VectorRef add1 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)});
  VectorRef get_item = VectorRef(
    {std::make_shared<Var>("GetItem"), placeholders[6], placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
  VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[4]});
  VectorRef matmul1 = VectorRef({std::make_shared<Var>("Matmul"), concat_input_h, placeholders[8]});
  VectorRef biasadd1 = VectorRef({std::make_shared<Var>("BiasAdd"), matmul1, placeholders[9]});
  VectorRef sigmoid1 = VectorRef({std::make_shared<Var>("Sigmoid"), biasadd1});
  VectorRef split = VectorRef({std::make_shared<Var>("Split"), sigmoid1});
  VectorRef get_item1 = VectorRef({std::make_shared<Var>("TupleGetItem"), split, std::make_shared<Var>()});
  VectorRef get_item2 = VectorRef({std::make_shared<Var>("TupleGetItem"), split, std::make_shared<Var>()});
  VectorRef pre_reset = VectorRef({std::make_shared<Var>("Mul"), get_item1, placeholders[4]});
  VectorRef concat2 = VectorRef({std::make_shared<Var>("Concat"), get_item, pre_reset});
  VectorRef matmul2 = VectorRef({std::make_shared<Var>("Matmul"), concat2, placeholders[10]});
  VectorRef biasadd2 = VectorRef({std::make_shared<Var>("BiasAdd"), matmul2, placeholders[11]});
  VectorRef tanh = VectorRef({std::make_shared<Var>("Tanh"), biasadd2});
  VectorRef update_hidden = VectorRef({std::make_shared<Var>("Mul"), get_item2, placeholders[4]});
  VectorRef minus_update =
    VectorRef({std::make_shared<Var>("Sub"), std::make_shared<CondVar>(IsParameterNode), get_item2});
  VectorRef updated = VectorRef({std::make_shared<Var>("Mul"), minus_update, tanh});
  VectorRef new_hidden = VectorRef({std::make_shared<Var>("Add"), update_hidden, updated});
  VectorRef greater_equal = VectorRef({std::make_shared<Var>("GreaterEqual"), placeholders[2], placeholders[7]});
  VectorRef select_output = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[12], new_hidden});
  VectorRef output = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], select_output});
  VectorRef select_hidden = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[4], new_hidden});
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple));
  std::vector<BaseRef> outputs = {is_make_tuple, add1,          placeholders[1], add,
                                  output,        select_hidden, placeholders[5], placeholders[6],
                                  placeholders[7]};
  outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return));
  VectorRef return_node = VectorRef({is_return, make_tuple_node});
  VarPtr fg = std::make_shared<Var>("RootG");
  auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true);
  return pattern;
}
ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNodePtr &parameter_anf) const {
  MS_ASSERT(parameter_anf != nullptr);
  if (!utils::isa<ParameterPtr>(parameter_anf)) {
    MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr";
    return nullptr;
  }
  auto parameter = utils::cast<ParameterPtr>(parameter_anf);
  if (!parameter->has_default()) {
    MS_LOG(DEBUG) << "parameter does not have a default value";
    return nullptr;
  }
  auto param_value = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
  return param_value;
}
STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf,
                                                         const AnfNodePtr &bw_cand_kernel_anf, int *input_size,
                                                         int *hidden_size) const {
  MS_ASSERT(fw_cand_kernel_anf != nullptr);
  MS_ASSERT(bw_cand_kernel_anf != nullptr);
  MS_ASSERT(input_size != nullptr);
  MS_ASSERT(hidden_size != nullptr);
  auto fw_cand_kernel_value = GetDefaultParamValue(fw_cand_kernel_anf);
  if (fw_cand_kernel_value == nullptr) {
    return RET_ERROR;
  }
  auto fw_cand_kernel_shape = fw_cand_kernel_value->tensor_shape();
  if (fw_cand_kernel_shape.size() != 2) {
    return RET_ERROR;
  }
  auto bw_cand_kernel_value = GetDefaultParamValue(bw_cand_kernel_anf);
  if (bw_cand_kernel_value == nullptr) {
    return RET_ERROR;
  }
  auto bw_cand_kernel_shape = bw_cand_kernel_value->tensor_shape();
  if (bw_cand_kernel_shape.size() != 2) {
    return RET_ERROR;
  }
  if (fw_cand_kernel_shape != bw_cand_kernel_shape) {
    return RET_ERROR;
  }
  if (fw_cand_kernel_shape[1] <= 0 || fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1] <= 0) {
    MS_LOG(DEBUG) << "gru input size or hidden size is illegal";
    return RET_ERROR;
  }
  *hidden_size = fw_cand_kernel_shape[1];
  *input_size = fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1];
  return RET_OK;
}
ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
                                                             const std::vector<int> &shape, const TypeId type,
                                                             void **tensor_data) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(tensor_data != nullptr);
  auto parameter = func_graph->add_parameter();
  parameter->set_name(name);
  std::vector<int64_t> shape_vector(shape.begin(), shape.end());
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector);
  if (abstract_tensor == nullptr) {
    return nullptr;
  }
  parameter->set_abstract(abstract_tensor);
  auto gate_weight_default = std::make_shared<ParamValueLite>();
  if (gate_weight_default == nullptr) {
    MS_LOG(ERROR) << "gate_weight_default is nullptr";
    return nullptr;
  }
  gate_weight_default->set_tensor_shape(shape);
  gate_weight_default->set_tensor_type(type);
  gate_weight_default->set_format(schema::Format_NHWC);
  int data_len = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
  int data_size = 0;
  if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) {
    data_size = data_len * sizeof(float);
    *tensor_data = new (std::nothrow) float[data_len];
  } else if (type == kNumberTypeInt || type == kNumberTypeInt32) {
    data_size = data_len * sizeof(int);
    *tensor_data = new (std::nothrow) int[data_len];
  } else {
    MS_LOG(DEBUG) << "unsupported data type";
    return nullptr;
  }
  if (*tensor_data == nullptr) {
    MS_LOG(ERROR) << "new data failed";
    return nullptr;
  }
  gate_weight_default->SetTensorData(*tensor_data, data_size);
  parameter->set_default_param(gate_weight_default);
  return parameter;
}
void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0,
                                                    const int r1, const int c0, const int c1, float *data,
                                                    bool t) const {
  MS_ASSERT(mat != nullptr);
  MS_ASSERT(data != nullptr);
  MS_ASSERT(0 <= r0 && r0 < r1 && r1 <= R);
  MS_ASSERT(0 <= c0 && c0 < c1 && c1 <= C);
  const int RT = r1 - r0;
  const int CT = c1 - c0;
  for (int i = r0; i < r1; ++i) {
    for (int j = c0; j < c1; ++j) {
      if (t) {
        data[(j - c0) * RT + (i - r0)] = mat[i * C + j];
      } else {
        data[(i - r0) * CT + (j - c0)] = mat[i * C + j];
      }
    }
  }
}
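A tiny worked example of the copy helper above, treating it as free-standing for the sake of illustration (it is a private member here): with a row-major 2x3 source and t = true, the selected block is written out transposed.

  const float mat[] = {1, 2, 3, 4, 5, 6};  // R=2, C=3, i.e. [[1, 2, 3], [4, 5, 6]]
  float out[2] = {0};
  CopyFlattenMatData(mat, 2, 3, 0, 2, 1, 2, out, true);  // rows [0,2) x cols [1,2), transposed
  // out == {2, 5}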
STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight,
                                                     const int input_size, const int hidden_size,
                                                     float *gate_tensor_data, float *recu_tensor_data) const {
  MS_ASSERT(gate_weight != nullptr);
  MS_ASSERT(cand_weight != nullptr);
  MS_ASSERT(gate_tensor_data != nullptr);
  MS_ASSERT(recu_tensor_data != nullptr);
  const std::vector<int> gate_shape{input_size + hidden_size, hidden_size * 2};
  const std::vector<int> cand_shape{hidden_size * 2, hidden_size};
  auto gate_weight_value = GetDefaultParamValue(gate_weight);
  if (gate_weight_value == nullptr) {
    return RET_ERROR;
  }
  auto gate_weight_data = reinterpret_cast<float *>(gate_weight_value->tensor_addr());
  if (gate_weight_data == nullptr) {
    return RET_ERROR;
  }
  auto gate_weight_shape = gate_weight_value->tensor_shape();
  auto cand_weight_value = GetDefaultParamValue(cand_weight);
  if (cand_weight_value == nullptr) {
    return RET_ERROR;
  }
  auto cand_weight_data = reinterpret_cast<float *>(cand_weight_value->tensor_addr());
  if (cand_weight_data == nullptr) {
    return RET_ERROR;
  }
  auto cand_weight_shape = cand_weight_value->tensor_shape();
  if (gate_weight_shape != gate_shape || cand_weight_shape != cand_shape) {
    return RET_ERROR;
  }
  // input_update_weight
  CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, 0, input_size, hidden_size,
                     hidden_size * 2, gate_tensor_data, true);
  // input_reset_weight
  CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, 0, input_size, 0, hidden_size,
                     gate_tensor_data + input_size * hidden_size, true);
  // input_hidden_weight
  CopyFlattenMatData(cand_weight_data, input_size + hidden_size, hidden_size, 0, input_size, 0, hidden_size,
                     gate_tensor_data + input_size * hidden_size * 2, true);
  // state_update_weight
  CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, input_size, input_size + hidden_size,
                     hidden_size, hidden_size * 2, recu_tensor_data, true);
  // state_reset_weight
  CopyFlattenMatData(gate_weight_data, input_size + hidden_size, hidden_size * 2, input_size, input_size + hidden_size,
                     0, hidden_size, recu_tensor_data + hidden_size * hidden_size, true);
  // state_hidden_weight
  CopyFlattenMatData(cand_weight_data, input_size + hidden_size, hidden_size, input_size, input_size + hidden_size, 0,
                     hidden_size, recu_tensor_data + hidden_size * hidden_size * 2, true);
  return RET_OK;
}
STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias,
                                                   const int hidden_size, float *tensor_data) const {
  MS_ASSERT(gate_bias != nullptr);
  MS_ASSERT(cand_bias != nullptr);
  MS_ASSERT(tensor_data != nullptr);
  std::vector<int> gate_shape{hidden_size * 2};
  std::vector<int> cand_shape{hidden_size};
  auto gate_bias_value = GetDefaultParamValue(gate_bias);
  if (gate_bias_value == nullptr) {
    return RET_ERROR;
  }
  auto gate_bias_data = reinterpret_cast<float *>(gate_bias_value->tensor_addr());
  auto gate_bias_shape = gate_bias_value->tensor_shape();
  auto cand_bias_value = GetDefaultParamValue(cand_bias);
  if (cand_bias_value == nullptr) {
    return RET_ERROR;
  }
  auto cand_bias_data = reinterpret_cast<float *>(cand_bias_value->tensor_addr());
  auto cand_bias_shape = cand_bias_value->tensor_shape();
  if (gate_bias_shape != gate_shape || cand_bias_shape != cand_shape) {
    return RET_ERROR;
  }
  // update_gate bias
  CopyFlattenMatData(gate_bias_data, 1, hidden_size * 2, 0, 1, hidden_size, hidden_size * 2, tensor_data, false);
  // reset_gate bias
  CopyFlattenMatData(gate_bias_data, 1, hidden_size * 2, 0, 1, 0, hidden_size, tensor_data + hidden_size, false);
  // hidden_gate bias
  CopyFlattenMatData(cand_bias_data, 1, hidden_size, 0, 1, 0, hidden_size, tensor_data + hidden_size * 2, false);
  return RET_OK;
}
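To summarize the fused bias layout produced above (offsets per direction, h = hidden_size; the [3h, 6h) range is zero-filled by the caller before these copies run):

  // [0,  h)   update gate bias   (second half of gate_bias)
  // [h,  2h)  reset gate bias    (first half of gate_bias)
  // [2h, 3h)  candidate bias     (cand_bias)
  // [3h, 6h)  recurrent biases, left at zero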
CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph,
                                                           const AnfNodePtr &hidden_state,
                                                           const std::string base_name) const {
  MS_ASSERT(func_graph);
  MS_ASSERT(hidden_state);
  auto stack_primitive = std::make_unique<schema::PrimitiveT>();
  std::unique_ptr<schema::StackT> attr = std::make_unique<schema::StackT>();
  attr->axis = 0;
  stack_primitive->value.type = schema::PrimitiveType_Stack;
  stack_primitive->value.value = attr.release();
  auto stack_cvalue = lite::PrimitiveC::Create(stack_primitive.release());
  auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(stack_cvalue));
  std::vector<AnfNodePtr> new_node_inputs = {value_node, hidden_state, hidden_state};
  auto new_node = func_graph->NewCNode(new_node_inputs);
  new_node->set_abstract(hidden_state->abstract()->Clone());
  new_node->set_fullname_with_scope("stack_hidden_" + base_name);
  return new_node;
}
CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                              const EquivPtr &equiv, const EquivPtr &fw_body_equiv,
                                                              const EquivPtr &bw_body_equiv,
                                                              const std::string &base_name) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(input != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(fw_body_equiv != nullptr);
  MS_ASSERT(bw_body_equiv != nullptr);
  auto gru_primitive = std::make_unique<schema::PrimitiveT>();
  std::unique_ptr<schema::GruT> attr = std::make_unique<schema::GruT>();
  attr->bidirection = true;
  gru_primitive->value.type = schema::PrimitiveType_Gru;
  gru_primitive->value.value = attr.release();
  auto gru_cvalue = lite::PrimitiveC::Create(gru_primitive.release());
  auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(gru_cvalue));
  auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]);
  MS_ASSERT(fw_gate_kernel);
  auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]);
  MS_ASSERT(fw_gate_bias);
  auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]);
  MS_ASSERT(fw_cand_kernel);
  auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]);
  MS_ASSERT(fw_cand_bias);
  auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]);
  MS_ASSERT(bw_gate_kernel);
  auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]);
  MS_ASSERT(bw_gate_bias);
  auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]);
  MS_ASSERT(bw_cand_kernel);
  auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]);
  MS_ASSERT(bw_cand_bias);
  auto hidden = utils::cast<AnfNodePtr>((*equiv)[common_vars_[1]]);
  MS_ASSERT(hidden);
  auto stacked_hidden = GetStackedHiddenState(func_graph, hidden, base_name);
  if (stacked_hidden == nullptr) {
    return nullptr;
  }
  auto input_length = utils::cast<AnfNodePtr>((*equiv)[input_length_]);
  MS_ASSERT(input_length);
  int input_size = 0;
  int hidden_size = 0;
  auto status = GetInputAndHiddenSize(fw_cand_kernel, bw_cand_kernel, &input_size, &hidden_size);
  if (status != RET_OK) {
    return nullptr;
  }
  std::vector<int> gate_weight_shape{2, hidden_size * 3, input_size};
  float *gate_tensor_data = nullptr;
  auto gate_weight = AddDefaultParameter(func_graph, base_name + "_gate_weight", gate_weight_shape, kNumberTypeFloat32,
                                         reinterpret_cast<void **>(&gate_tensor_data));
  if (gate_weight == nullptr) {
    return nullptr;
  }
  std::vector<int> recu_weight_shape{2, hidden_size * 3, hidden_size};
  float *recu_tensor_data = nullptr;
  auto recu_weight = AddDefaultParameter(func_graph, base_name + "_cand_weight", recu_weight_shape, kNumberTypeFloat32,
                                         reinterpret_cast<void **>(&recu_tensor_data));
  if (recu_weight == nullptr) {
    return nullptr;
  }
  std::vector<int> bias_shape{2, hidden_size * 6};
  float *bias_tensor_data = nullptr;
  auto bias = AddDefaultParameter(func_graph, base_name + "_bias", bias_shape, kNumberTypeFloat32,
                                  reinterpret_cast<void **>(&bias_tensor_data));
  if (bias == nullptr) {
    return nullptr;
  }
  for (int i = 0; i < 2 * hidden_size * 6; ++i) {
    bias_tensor_data[i] = 0.0f;
  }
  if (ConvertWeightData(fw_gate_kernel, fw_cand_kernel, input_size, hidden_size, gate_tensor_data, recu_tensor_data) !=
      RET_OK) {
    return nullptr;
  }
  auto gate_data_diff = hidden_size * input_size * 3;
  auto recu_data_diff = hidden_size * hidden_size * 3;
  if (ConvertWeightData(bw_gate_kernel, bw_cand_kernel, input_size, hidden_size, gate_tensor_data + gate_data_diff,
                        recu_tensor_data + recu_data_diff) != RET_OK) {
    return nullptr;
  }
  if (ConvertBiasData(fw_gate_bias, fw_cand_bias, hidden_size, bias_tensor_data) != RET_OK) {
    return nullptr;
  }
  auto bias_data_diff = hidden_size * 6;
  if (ConvertBiasData(bw_gate_bias, bw_cand_bias, hidden_size, bias_tensor_data + bias_data_diff) != RET_OK) {
    return nullptr;
  }
  std::vector<AnfNodePtr> new_node_inputs = {value_node, input,          gate_weight, recu_weight,
                                             bias,       stacked_hidden, input_length};
  auto new_node = func_graph->NewCNode(new_node_inputs);
  new_node->set_fullname_with_scope(base_name);
  return new_node;
}
CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
                                                        const std::string base_name) const {
  MS_ASSERT(func_graph);
  MS_ASSERT(gru_output);
  auto split_primitive = std::make_unique<schema::PrimitiveT>();
  std::unique_ptr<schema::SplitT> split_attr = std::make_unique<schema::SplitT>();
  split_attr->numberSplit = 2;
  split_attr->splitDim = 1;
  split_primitive->value.type = schema::PrimitiveType_Split;
  split_primitive->value.value = split_attr.release();
  auto split_cvalue = lite::PrimitiveC::Create(split_primitive.release());
  auto split_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(split_cvalue));
  std::vector<AnfNodePtr> new_node_inputs = {split_value_node, gru_output};
  auto split_new_node = func_graph->NewCNode(new_node_inputs);
  split_new_node->set_fullname_with_scope("split_" + base_name);
  if (TfliteLstmCellFusion::SetAbstractTuple(split_new_node, 2) != RET_OK) {
    return nullptr;
  }
  auto split_out1 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 0);
  if (split_out1 == nullptr) {
    return nullptr;
  }
  auto split_out2 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 1);
  if (split_out2 == nullptr) {
    return nullptr;
  }
  auto concat_primitive = std::make_unique<schema::PrimitiveT>();
  std::unique_ptr<schema::ConcatT> concat_attr = std::make_unique<schema::ConcatT>();
  concat_attr->axis = 3;
  concat_primitive->value.type = schema::PrimitiveType_Concat;
  concat_primitive->value.value = concat_attr.release();
  auto concat_cvalue = lite::PrimitiveC::Create(concat_primitive.release());
  auto concat_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(concat_cvalue));
  std::vector<AnfNodePtr> concat_new_node_inputs = {concat_value_node, split_out1, split_out2};
  auto concat_new_node = func_graph->NewCNode(concat_new_node_inputs);
  concat_new_node->set_fullname_with_scope("concat_" + base_name);
  concat_new_node->set_abstract(gru_output->abstract()->Clone());
  auto squeeze_primitive = std::make_unique<schema::PrimitiveT>();
  std::unique_ptr<schema::SqueezeT> squeeze_attr = std::make_unique<schema::SqueezeT>();
  squeeze_attr->axis = std::vector<int>{1};
  squeeze_primitive->value.type = schema::PrimitiveType_Squeeze;
  squeeze_primitive->value.value = squeeze_attr.release();
  auto squeeze_cvalue = lite::PrimitiveC::Create(squeeze_primitive.release());
  auto squeeze_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(squeeze_cvalue));
  std::vector<AnfNodePtr> squeeze_new_node_inputs = {squeeze_value_node, concat_new_node};
  auto squeeze_new_node = func_graph->NewCNode(squeeze_new_node_inputs);
  squeeze_new_node->set_fullname_with_scope("squeeze_" + base_name);
  squeeze_new_node->set_abstract(gru_output->abstract()->Clone());
  auto transpose_primitive = std::make_unique<schema::PrimitiveT>();
  std::unique_ptr<schema::TransposeT> transpose_attr = std::make_unique<schema::TransposeT>();
  transpose_attr->perm = std::vector<int>{1, 0, 2};
  transpose_primitive->value.type = schema::PrimitiveType_Transpose;
  transpose_primitive->value.value = transpose_attr.release();
  auto transpose_cvalue = lite::PrimitiveC::Create(transpose_primitive.release());
  auto transpose_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(transpose_cvalue));
  std::vector<AnfNodePtr> transpose_new_node_inputs = {transpose_value_node, squeeze_new_node};
  auto transpose_new_node = func_graph->NewCNode(transpose_new_node_inputs);
  transpose_new_node->set_fullname_with_scope("transpose_" + base_name);
  transpose_new_node->set_abstract(gru_output->abstract()->Clone());
  return transpose_new_node;
}
const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
                                                     const EquivPtr &equiv) const {
  MS_ASSERT(func_graph);
  MS_ASSERT(concat_node);
  MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
  MS_ASSERT(transpose_input);
  if (!utils::isa<CNodePtr>(transpose_input) || GetCNodeType(transpose_input) != schema::PrimitiveType_Transpose) {
    return nullptr;
  }
  PrimitiveVarMapPtr fw_cond_primitive_vars = std::make_shared<PrimitiveVarMap>();
  auto fw_cond_graph_pattern = GetCondGraphPattern(fw_cond_primitive_vars);
  auto fw_cond = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[0]]);
  MS_ASSERT(fw_cond != nullptr);
  auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_cond_graph_pattern, fw_cond_primitive_vars,
                                                           fw_cond, kCondCNodesNum, kCondNodesNum);
  if (fw_cond_equiv == nullptr || fw_cond_equiv->empty()) {
    return nullptr;
  }
  PrimitiveVarMapPtr bw_cond_primitive_vars = std::make_shared<PrimitiveVarMap>();
  auto bw_cond_graph_pattern = GetCondGraphPattern(bw_cond_primitive_vars);
  auto bw_cond = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[0]]);
  MS_ASSERT(bw_cond != nullptr);
  auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_cond_graph_pattern, bw_cond_primitive_vars,
                                                           bw_cond, kCondCNodesNum, kCondNodesNum);
  if (bw_cond_equiv == nullptr || bw_cond_equiv->empty()) {
    return nullptr;
  }
  PrimitiveVarMapPtr fw_primitive_vars_body = std::make_shared<PrimitiveVarMap>();
  auto fw_body_graph_pattern = GetBodyGraphPattern(fw_primitive_vars_body);
  auto fw_body = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[1]]);
  MS_ASSERT(fw_body != nullptr);
  auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_body_graph_pattern, fw_primitive_vars_body,
                                                           fw_body, kBodyCNodesNum, kBodyNodesNum);
  if (fw_body_equiv == nullptr || fw_body_equiv->empty()) {
    return nullptr;
  }
  PrimitiveVarMapPtr bw_primitive_vars_body = std::make_shared<PrimitiveVarMap>();
  auto bw_body_graph_pattern = GetBodyGraphPattern(bw_primitive_vars_body);
  auto bw_body = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[1]]);
  MS_ASSERT(bw_body != nullptr);
  auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_body_graph_pattern, bw_primitive_vars_body,
                                                           bw_body, kBodyCNodesNum, kBodyNodesNum);
  if (bw_body_equiv == nullptr || bw_body_equiv->empty()) {
    return nullptr;
  }
  const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
  auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, fw_body_equiv, bw_body_equiv, gru_name);
  if (gru_node == nullptr) {
    return nullptr;
  }
  if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
    return nullptr;
  }
  auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
  if (get_item_node == nullptr) {
    return nullptr;
  }
  auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
  MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
  return output_node;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,73 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "include/errorcode.h"
namespace mindspore {
namespace opt {
class BiDirectionTfGruCellFusion : public PatternProcessPass {
 public:
  explicit BiDirectionTfGruCellFusion(const std::string &name = "bidirection_tf_gru_cell_fusion",
                                      bool multigraph = true);
  ~BiDirectionTfGruCellFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 protected:
  virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;

 private:
  AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
  CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
                                    const EquivPtr &fw_body_equiv, const EquivPtr &bw_body_equiv,
                                    const std::string &base_name) const;
  ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr &parameter_anf) const;
  lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf,
                                     int *input_size, int *hidden_size) const;
  ParameterPtr AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
                                   const std::vector<int> &shape, const TypeId type, void **tensor_data) const;
  lite::STATUS ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, const int input_size,
                                 const int hidden_size, float *gate_tensor_data, float *recu_tensor_data) const;
  lite::STATUS ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, const int hidden_size,
                               float *tensor_data) const;
  void CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, const int c0,
                          const int c1, float *data, bool t = false) const;
  CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &hidden_state,
                                 const std::string base_name) const;
  CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
                              const std::string base_name) const;

 private:
  std::vector<VarPtr> common_vars_;
  std::vector<VarPtr> fw_vars_;
  std::vector<VarPtr> bw_vars_;
  VarPtr input_;
  VarPtr input_length_;
  VarPtr transpose_input_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
@@ -0,0 +1,370 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/common/utils.h"
#include "src/param_value_lite.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kLstmInputsLength = 13;
constexpr size_t kLstmInputsVarNum = 11;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 82;
constexpr size_t kBodyCNodesNum = 30;
const auto &p1 = std::placeholders::_1;
bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    return opt::GetCNodeType(n) == type;
  }
  return false;
}
}  // namespace
TfLstmCellFusion::TfLstmCellFusion(const std::string &name, bool multigraph)
    : TfliteLstmCellFusion(name, multigraph, kLstmInputsLength, kLstmInputsVarNum, kCondNodesNum, kCondCNodesNum,
                           kBodyNodesNum, kBodyCNodesNum) {
  /*
   * vars for lstm cell input
   * 0:cond 1:body 2:index 3:limit1 4:output 5:cell 6:hidden 7:limit2 8:input 9:kernel 10:bias
   */
  forget_bias_ = std::make_shared<Var>();
}
AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
  std::vector<CondVarPtr> placeholders;
  for (int i = 0; i < 10; ++i) {
    placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
  }
  VectorRef add2 = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
  VectorRef add3 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)});
  VectorRef get_item = VectorRef(
    {std::make_shared<Var>("GetItem"), placeholders[7], placeholders[2], std::make_shared<CondVar>(IsParameterNode)});
  VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[5]});
  VectorRef matmul = VectorRef({std::make_shared<Var>(), concat_input_h, placeholders[8]});
  VectorRef bias = VectorRef({std::make_shared<Var>(), matmul, placeholders[9]});
  VectorRef split = VectorRef({std::make_shared<Var>(), bias});
  VectorRef get_item1 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()});
  VectorRef get_item2 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()});
  VectorRef get_item3 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()});
  VectorRef get_item4 = VectorRef({std::make_shared<Var>(), split, std::make_shared<Var>()});
  VectorRef input_gate = VectorRef({std::make_shared<Var>("Sigmoid"), get_item1});
  VectorRef input_to_cell = VectorRef({std::make_shared<Var>("Tanh"), get_item2});
  VectorRef forget_bias = VectorRef({std::make_shared<Var>("Add"), get_item3, forget_bias_});
  VectorRef forget_gate = VectorRef({std::make_shared<Var>("Sigmoid"), forget_bias});
  VectorRef output_gate = VectorRef({std::make_shared<Var>("Sigmoid"), get_item4});
  VectorRef forgetted_cell = VectorRef({std::make_shared<Var>(""), forget_gate, placeholders[4]});
  VectorRef inputted_cell = VectorRef({std::make_shared<Var>(""), input_gate, input_to_cell});
  VectorRef input_forget_cell = VectorRef({std::make_shared<Var>("Add"), forgetted_cell, inputted_cell});
  VectorRef to_new_hidden = VectorRef({std::make_shared<Var>("Tanh"), input_forget_cell});
  VectorRef new_hidden = VectorRef({std::make_shared<Var>("Mul"), output_gate, to_new_hidden});
  VectorRef new_to_cell = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_new_, input_forget_cell});
  VectorRef old_to_cell = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_old_, placeholders[4]});
  VectorRef output_cell = VectorRef({std::make_shared<Var>("Add"), new_to_cell, old_to_cell});
  VectorRef new_to_hidden = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_new_, new_hidden});
  VectorRef old_to_hidden = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_old_, placeholders[5]});
  VectorRef output_hidden = VectorRef({std::make_shared<Var>("Add"), new_to_hidden, old_to_hidden});
  VectorRef set_item = VectorRef({std::make_shared<Var>(""), placeholders[3], placeholders[2], new_hidden});
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple));
  std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden};
  outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return));
  VectorRef return_node = VectorRef({is_return, make_tuple_node});
  VarPtr fg = std::make_shared<Var>("RootG");
  auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true);
  return pattern;
}
STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector<int> &shape,
                                                     const float *const data_ptr, const int hidden_size) const {
  MS_ASSERT(weight != nullptr);
  MS_ASSERT(data_ptr != nullptr);
  auto default_param = std::make_shared<ParamValueLite>();
  if (default_param == nullptr) {
    MS_LOG(ERROR) << "new_default is nullptr";
    return RET_ERROR;
  }
  default_param->set_tensor_shape(shape);
  default_param->set_tensor_type(kNumberTypeFloat32);
  default_param->set_format(schema::Format_NHWC);
  if (shape.size() != 3) {
    MS_LOG(ERROR) << "lstm weight shape must have 3 dims";
    return RET_ERROR;
  }
  const auto param_num = shape[0] * shape[1] * shape[2];
  auto tensor_data = new (std::nothrow) float[param_num * 4];
  std::vector<int> data_diff{0, 3, 2, 1};  // reorders TF's i,c,f,o gate blocks into i,o,f,c
  if (tensor_data == nullptr) {
    MS_LOG(DEBUG) << "new data failed";
    return RET_ERROR;
  }
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < hidden_size; ++j) {
      for (int t = 0; t < shape[2]; ++t) {
        tensor_data[(i * hidden_size + j) * shape[2] + t] = data_ptr[t * shape[1] + data_diff[i] * hidden_size + j];
      }
    }
  }
  default_param->SetTensorData(tensor_data, param_num * 4);  // byte size: param_num floats * 4 bytes each
  weight->set_default_param(default_param);
  std::vector<int64_t> shape_vector_i(shape.begin(), shape.end());
  auto abstract_tensor_i = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector_i);
  if (abstract_tensor_i == nullptr) {
    MS_LOG(ERROR) << "abstract_tensor is nullptr";
    delete[] tensor_data;
    return RET_ERROR;
  }
  weight->set_abstract(abstract_tensor_i);
  return RET_OK;
}
STATUS TfLstmCellFusion::SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i,
                                      const ParameterPtr &weight_c, int hidden_size) const {
  // split input_size and hidden_size at dim 0
  // transform i,c,f,o to i,o,f,c at dim 1
  MS_ASSERT(weight != nullptr);
  MS_ASSERT(weight_i != nullptr);
  MS_ASSERT(weight_c != nullptr);
  if (!utils::isa<ParameterPtr>(weight)) {
    return RET_ERROR;
  }
  auto weight_param = utils::cast<ParameterPtr>(weight);
  if (!weight_param->has_default()) {
    MS_LOG(DEBUG) << "weight does not have a default value";
    return RET_ERROR;
  }
  if (!utils::isa<ParamValueLitePtr>(weight_param->default_param())) {
    MS_LOG(DEBUG) << "default value is not ParamValueLite";
    return RET_FAILED;
  }
  auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(weight_param->default_param());
  if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) {
    MS_LOG(DEBUG) << "origin_tensor is not float32 type";
    return RET_ERROR;
  }
  auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr());
  auto data_shape = origin_tensor->tensor_shape();
  if (data_shape.size() != 2) {
    MS_LOG(ERROR) << "weight data shape invalid";
    return RET_ERROR;
  }
  if (data_shape[0] <= hidden_size) {
    MS_LOG(ERROR) << "weight data shape[0] invalid";
    return RET_ERROR;
  }
  if (hidden_size * 4 != data_shape[1]) {
    MS_LOG(ERROR) << "weight data shape[1] invalid";
    return RET_ERROR;
  }
  const auto input_size = data_shape[0] - hidden_size;
  std::vector<int> shape_i{1, 4 * hidden_size, input_size};
  if (SetWeightAbstractAndDefault(weight_i, shape_i, data_ptr, hidden_size) != RET_OK) {
    MS_LOG(ERROR) << "get weight_i failed";
    return RET_ERROR;
  }
  std::vector<int> shape_c{1, 4 * hidden_size, hidden_size};
  if (SetWeightAbstractAndDefault(weight_c, shape_c, data_ptr + input_size * data_shape[1], hidden_size) != RET_OK) {
    MS_LOG(ERROR) << "get weight_c failed";
    return RET_ERROR;
  }
  return RET_OK;
}
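The data_diff table {0, 3, 2, 1} in SetWeightAbstractAndDefault is what performs the i,c,f,o to i,o,f,c reorder named in the comment above. A worked example with hidden_size = 1 and a single input column (shape[2] = 1):

  // TF kernel gate blocks for one input row, ordered i, c, f, o:
  //   data_ptr    = {I, C, F, O}
  // output rows are taken at offset data_diff[i] * hidden_size:
  //   tensor_data = {I, O, F, C}   // i, o, f, c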
| STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias, | |||||
| const AnfNodePtr &old_bias, const int hidden_size) const { | |||||
| MS_ASSERT(body_equiv != nullptr); | |||||
| MS_ASSERT(new_bias != nullptr); | |||||
| MS_ASSERT(old_bias != nullptr); | |||||
| if (!utils::isa<ParameterPtr>(old_bias)) { | |||||
| MS_LOG(DEBUG) << "old_bias is not parameter"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto old_bias_param = utils::cast<ParameterPtr>(old_bias); | |||||
| if (!old_bias_param->has_default()) { | |||||
| MS_LOG(DEBUG) << "bias not have default value"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<ParamValueLitePtr>(old_bias_param->default_param())) { | |||||
| MS_LOG(DEBUG) << "default value is not ParamValueLite"; | |||||
| return RET_FAILED; | |||||
| } | |||||
| auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(old_bias_param->default_param()); | |||||
| if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { | |||||
| MS_LOG(DEBUG) << "origin_tensor is not float32 type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr()); | |||||
| auto data_shape = origin_tensor->tensor_shape(); | |||||
| if (data_shape.size() != 1 || data_shape[0] != 4 * hidden_size) { | |||||
| MS_LOG(DEBUG) << "bias data shape illegal"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int> shape{1, 8 * hidden_size}; | |||||
| auto default_param = std::make_shared<ParamValueLite>(); | |||||
| if (default_param == nullptr) { | |||||
| MS_LOG(ERROR) << "new_default is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| default_param->set_tensor_shape(shape); | |||||
| default_param->set_tensor_type(kNumberTypeFloat32); | |||||
| default_param->set_format(schema::Format_NHWC); | |||||
| auto forget_bias_node = utils::cast<AnfNodePtr>((*body_equiv)[forget_bias_]); | |||||
| if (forget_bias_node == nullptr) { | |||||
| MS_LOG(ERROR) << "forget bias node is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| float forget_bias_value = 0.0f; | |||||
| if (GetFloatScalarFromParamValueLite(forget_bias_node, &forget_bias_value) != RET_OK) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto tensor_data = new (std::nothrow) float[hidden_size * 8]; | |||||
| if (tensor_data == nullptr) { | |||||
| MS_LOG(ERROR) << "new tensor_data failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
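| // TF orders the gate slices i, c, f, o while the fused node expects i, o, f, c, so data_diff | |||||
| // maps each destination gate to its source slice; the recurrent half of the bias stays zero | |||||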
| std::vector<int> data_diff{0, 3, 2, 1}; | |||||
| for (int i = 0; i < 8; ++i) { | |||||
| for (int j = 0; j < hidden_size; ++j) { | |||||
| if (i < 4) { | |||||
| tensor_data[i * hidden_size + j] = data_ptr[data_diff[i] * hidden_size + j]; | |||||
| if (i == 2) { // forget bias | |||||
| tensor_data[i * hidden_size + j] += forget_bias_value; | |||||
| } | |||||
| } else { | |||||
| tensor_data[i * hidden_size + j] = 0.0f; | |||||
| } | |||||
| } | |||||
| } | |||||
| default_param->SetTensorData(tensor_data, hidden_size * 8 * sizeof(float)); | |||||
| new_bias->set_default_param(default_param); | |||||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||||
| // tensor_data is owned by default_param after SetTensorData, so it must not be deleted here | |||||
| return RET_ERROR; | |||||
| } | |||||
| new_bias->set_abstract(abstract_tensor); | |||||
| return RET_OK; | |||||
| } | |||||
| CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||||
| const EquivPtr &body_equiv, const std::string &base_name, | |||||
| const float smooth) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(equiv != nullptr); | |||||
| auto lstm_primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>(); | |||||
| attr->bidirection = false; | |||||
| attr->smooth = smooth; | |||||
| lstm_primitive->value.type = schema::PrimitiveType_Lstm; | |||||
| lstm_primitive->value.value = attr.release(); | |||||
| auto lstm_cvalue = lite::PrimitiveC::Create(lstm_primitive.release()); | |||||
| auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(lstm_cvalue)); | |||||
| auto &vars = while_input_vars_; | |||||
| auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]); | |||||
| MS_ASSERT(limit1); | |||||
| auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]); | |||||
| MS_ASSERT(limit2); | |||||
| auto weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]); | |||||
| MS_ASSERT(weight); | |||||
| auto bias = utils::cast<AnfNodePtr>((*equiv)[vars[10]]); | |||||
| MS_ASSERT(bias); | |||||
| auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]); | |||||
| MS_ASSERT(input); | |||||
| auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]); | |||||
| MS_ASSERT(cell); | |||||
| auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]); | |||||
| MS_ASSERT(hidden); | |||||
| if (!utils::isa<ParameterPtr>(hidden)) { | |||||
| MS_LOG(DEBUG) << "hidden is not parameter"; | |||||
| return nullptr; | |||||
| } | |||||
| auto hidden_param = utils::cast<ParameterPtr>(hidden); | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(hidden_param->abstract())) { | |||||
| MS_LOG(DEBUG) << "hidden abstract is not AbstractTensor"; | |||||
| return nullptr; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract()); | |||||
| auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||||
| if (hidden_shape.empty()) { | |||||
| MS_LOG(DEBUG) << "can't get hidden shape"; | |||||
| return nullptr; | |||||
| } | |||||
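| // hidden_shape.back() is the hidden size; the cloned abstracts below are placeholders that | |||||
| // SplitWeights and PopulateBiasNode replace with the fused shapes | |||||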
| auto i_weight = func_graph->add_parameter(); | |||||
| i_weight->set_name(base_name + "_weight_i"); | |||||
| i_weight->set_abstract(weight->abstract()->Clone()); | |||||
| auto c_weight = func_graph->add_parameter(); | |||||
| c_weight->set_name(base_name + "_weight_c"); | |||||
| c_weight->set_abstract(weight->abstract()->Clone()); | |||||
| if (SplitWeights(weight, i_weight, c_weight, hidden_shape.back()) != RET_OK) { | |||||
| MS_LOG(DEBUG) << "split weight to i_weight and c_weight failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto bias_node = func_graph->add_parameter(); | |||||
| bias_node->set_name(base_name + "_bias"); | |||||
| bias_node->set_abstract(bias->abstract()->Clone()); | |||||
| if (PopulateBiasNode(body_equiv, bias_node, bias, hidden_shape.back()) != RET_OK) { | |||||
| MS_LOG(DEBUG) << "reorder bias failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (!utils::isa<CNodePtr>(input) || GetCNodeType(input) != schema::PrimitiveType_TensorListFromTensor) { | |||||
| MS_LOG(DEBUG) << "input is not tensorlistfromtensor op"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_list_cnode = utils::cast<CNodePtr>(input); | |||||
| auto input_tensor_node = tensor_list_cnode->input(1); | |||||
| std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias_node, hidden, | |||||
| cell}; | |||||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||||
| new_node->set_fullname_with_scope(base_name); | |||||
| return new_node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "utils/utils.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "include/errorcode.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class TfLstmCellFusion : public TfliteLstmCellFusion { | |||||
| public: | |||||
| explicit TfLstmCellFusion(const std::string &name = "lstm_cell_fusion", bool multigraph = true); | |||||
| ~TfLstmCellFusion() override = default; | |||||
| private: | |||||
| AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const override; | |||||
| CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv, | |||||
| const std::string &base_name, const float smooth) const override; | |||||
| lite::STATUS SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i, const ParameterPtr &weight_c, | |||||
| int hidden_size) const; | |||||
| lite::STATUS SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector<int> &shape, | |||||
| const float *const data_ptr, const int hidden_size) const; | |||||
| lite::STATUS PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias, const AnfNodePtr &old_bias, | |||||
| const int hidden_size) const; | |||||
| private: | |||||
| VarPtr forget_bias_ = nullptr; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_LSTM_CELL_FUSION_H_ | |||||
| @@ -0,0 +1,727 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||||
| #include <memory> | |||||
| #include <functional> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "utils/utils.h" | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| #include "securec/include/securec.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| constexpr size_t kWhileInputsLength = 23; | |||||
| constexpr size_t kWhileInputsVarNum = 21; | |||||
| constexpr size_t kCondNodesNum = 12; | |||||
| constexpr size_t kCondCNodesNum = 4; | |||||
| constexpr size_t kBodyNodesNum = 95; | |||||
| constexpr size_t kBodyCNodesNum = 34; | |||||
| constexpr size_t kLSTMOutputNum = 3; | |||||
| const auto &p1 = std::placeholders::_1; | |||||
| constexpr float EPSILON = 1e-5f; | |||||
| bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); } | |||||
| bool IsOpType(const BaseRef &n, const schema::PrimitiveType &type) { | |||||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||||
| return opt::GetCNodeType(n) == type; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace | |||||
| STATUS TfliteLstmCellFusion::GetFloatScalarFromParamValueLite(const AnfNodePtr ¶m_value, float *v) const { | |||||
| if (param_value == nullptr || v == nullptr) { | |||||
| MS_LOG(ERROR) << "param_value or v is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<ParameterPtr>(param_value)) { | |||||
| MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto param_ptr = utils::cast<ParameterPtr>(param_value); | |||||
| if (!param_ptr->has_default()) { | |||||
| MS_LOG(DEBUG) << "param not have default"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto default_param = param_ptr->default_param(); | |||||
| if (!utils::isa<ParamValueLitePtr>(default_param)) { | |||||
| MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto default_param_ptr = utils::cast<ParamValueLitePtr>(default_param); | |||||
| auto tensor_shape = default_param_ptr->tensor_shape(); | |||||
| if (!(tensor_shape.size() == 0 || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) { | |||||
| MS_LOG(DEBUG) << "default param is not scalar"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (default_param_ptr->tensor_type() != kNumberTypeFloat32 && default_param_ptr->tensor_type() != kNumberTypeFloat) { | |||||
| MS_LOG(DEBUG) << "default param is not float"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *v = *(reinterpret_cast<float *>(default_param_ptr->tensor_addr())); | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num, | |||||
| int cond_nodes_num, int cond_cnodes_num, int body_nodes_num, | |||||
| int body_cnodes_num) | |||||
| : PatternProcessPass(name, multigraph) { | |||||
| /* | |||||
| * input vars for lstm while node | |||||
| * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_ | |||||
| * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_ | |||||
| */ | |||||
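| // a zero argument means "use the default tflite pattern sizes"; subclasses pass their own counts | |||||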
| this->while_inputs_num_ = input_length == 0 ? kWhileInputsLength : input_length; | |||||
| this->while_input_var_num_ = var_num == 0 ? kWhileInputsVarNum : var_num; | |||||
| this->cond_nodes_num_ = cond_nodes_num == 0 ? kCondNodesNum : cond_nodes_num; | |||||
| this->cond_cnodes_num_ = cond_cnodes_num == 0 ? kCondCNodesNum : cond_cnodes_num; | |||||
| this->body_nodes_num_ = body_nodes_num == 0 ? kBodyNodesNum : body_nodes_num; | |||||
| this->body_cnodes_num_ = body_cnodes_num == 0 ? kBodyCNodesNum : body_cnodes_num; | |||||
| for (size_t i = 0; i < this->while_input_var_num_; ++i) { | |||||
| while_input_vars_.emplace_back(std::make_shared<Var>()); | |||||
| } | |||||
| cell_smooth_old_ = std::make_shared<Var>(); | |||||
| cell_smooth_new_ = std::make_shared<Var>(); | |||||
| hidden_smooth_old_ = std::make_shared<Var>(); | |||||
| hidden_smooth_new_ = std::make_shared<Var>(); | |||||
| } | |||||
| AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||||
| auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode); | |||||
| auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode); | |||||
| auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode); | |||||
| auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode); | |||||
| auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); | |||||
| auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Less)); | |||||
| auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_LogicalAnd)); | |||||
| auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); | |||||
| VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2}); | |||||
| VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4}); | |||||
| VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref}); | |||||
| VectorRef return_ref = VectorRef({is_return, logicaland_ref}); | |||||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||||
| auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true); | |||||
| return pattern; | |||||
| } | |||||
| AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||||
| std::vector<CondVarPtr> placeholders; | |||||
| for (int i = 0; i < 20; ++i) { | |||||
| placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode)); | |||||
| } | |||||
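| // add2 and add3 match the two loop-counter increments in the body | |||||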
| VectorRef add2 = VectorRef({std::make_shared<Var>(), placeholders[2], std::make_shared<CondVar>(IsParameterNode)}); | |||||
| VectorRef add3 = VectorRef({std::make_shared<Var>(), placeholders[0], std::make_shared<CondVar>(IsParameterNode)}); | |||||
| VectorRef concat_i_w = VectorRef({std::make_shared<Var>(), placeholders[8], placeholders[12]}); | |||||
| VectorRef concat_f_w = VectorRef({std::make_shared<Var>(), placeholders[9], placeholders[13]}); | |||||
| VectorRef concat_c_w = VectorRef({std::make_shared<Var>(), placeholders[10], placeholders[14]}); | |||||
| VectorRef concat_o_w = VectorRef({std::make_shared<Var>(), placeholders[11], placeholders[15]}); | |||||
| VectorRef get_item = VectorRef( | |||||
| {std::make_shared<Var>("GetItem"), placeholders[7], placeholders[2], std::make_shared<CondVar>(IsParameterNode)}); | |||||
| VectorRef concat_input_h = VectorRef({std::make_shared<Var>(), get_item, placeholders[5]}); | |||||
| VectorRef matmul_input = VectorRef({std::make_shared<Var>(), concat_input_h, concat_i_w}); | |||||
| VectorRef matmul_forget = VectorRef({std::make_shared<Var>(), concat_input_h, concat_f_w}); | |||||
| VectorRef matmul_cell = VectorRef({std::make_shared<Var>(), concat_input_h, concat_c_w}); | |||||
| VectorRef matmul_output = VectorRef({std::make_shared<Var>(), concat_input_h, concat_o_w}); | |||||
| VectorRef bias_input = VectorRef({std::make_shared<Var>(), matmul_input, placeholders[16]}); | |||||
| VectorRef bias_forget = VectorRef({std::make_shared<Var>(), matmul_forget, placeholders[17]}); | |||||
| VectorRef bias_cell = VectorRef({std::make_shared<Var>(), matmul_cell, placeholders[18]}); | |||||
| VectorRef bias_output = VectorRef({std::make_shared<Var>(), matmul_output, placeholders[19]}); | |||||
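| // standard LSTM cell equations follow: sigmoid input/forget/output gates, tanh candidate cell, | |||||
| // c_new = f * c_old + i * tanh(candidate) and h = o * tanh(c_new), each blended by a smooth factor | |||||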
| VectorRef cell = VectorRef({std::make_shared<Var>("Tanh"), bias_cell}); | |||||
| VectorRef input_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_input}); | |||||
| VectorRef cell_input = VectorRef({std::make_shared<Var>("Mul"), input_gate, cell}); | |||||
| VectorRef forget_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_forget}); | |||||
| VectorRef cell_forgotten = VectorRef({std::make_shared<Var>("Mul"), forget_gate, placeholders[4]}); | |||||
| VectorRef cell_new = VectorRef({std::make_shared<Var>("Add"), cell_forgotten, cell_input}); | |||||
| VectorRef smooth_cell_old = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_old_, placeholders[4]}); | |||||
| VectorRef smooth_cell_new = VectorRef({std::make_shared<Var>("Mul"), cell_smooth_new_, cell_new}); | |||||
| VectorRef cell_output = VectorRef({std::make_shared<Var>("Add"), smooth_cell_new, smooth_cell_old}); | |||||
| VectorRef output_gate = VectorRef({std::make_shared<Var>("Sigmoid"), bias_output}); | |||||
| VectorRef cell_to_output = VectorRef({std::make_shared<Var>("Tanh"), cell_new}); | |||||
| VectorRef output = VectorRef({std::make_shared<Var>("Mul"), output_gate, cell_to_output}); | |||||
| VectorRef smooth_hidden_old = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_old_, placeholders[5]}); | |||||
| VectorRef smooth_hidden_new = VectorRef({std::make_shared<Var>("Mul"), hidden_smooth_new_, output}); | |||||
| VectorRef hidden_output = VectorRef({std::make_shared<Var>("Add"), smooth_hidden_new, smooth_hidden_old}); | |||||
| VectorRef set_item = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], output}); | |||||
| auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_MakeTuple)); | |||||
| std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output}; | |||||
| outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end()); | |||||
| VectorRef make_tuple_node = VectorRef(outputs); | |||||
| auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Return)); | |||||
| VectorRef return_node = VectorRef({is_return, make_tuple_node}); | |||||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||||
| auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true); | |||||
| return pattern; | |||||
| } | |||||
| const BaseRef TfliteLstmCellFusion::DefinePattern() const { | |||||
| auto is_while_node = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While)); | |||||
| VectorRef while_node = VectorRef({is_while_node}); | |||||
| auto while_inputs = while_input_vars_; | |||||
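| // the matched while node carries the time variable twice, so vars[2] is reused at position 4 | |||||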
| while_inputs.insert(while_inputs.begin() + 4, while_input_vars_[2]); | |||||
| while_node.insert(while_node.end(), while_inputs.begin(), while_inputs.end()); | |||||
| auto is_tuple_get_item = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)); | |||||
| VectorRef while_output = VectorRef({is_tuple_get_item, while_node, std::make_shared<Var>()}); | |||||
| auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)); | |||||
| auto is_parameter = std::make_shared<CondVar>(IsParameterNode); | |||||
| VectorRef tensor_list_stack_node = VectorRef({is_tensor_list_stack, while_output, is_parameter}); | |||||
| return tensor_list_stack_node; | |||||
| } | |||||
| EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars, | |||||
| const AnfNodePtr &pattern) { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(pattern != nullptr); | |||||
| auto return_node = func_graph->get_return(); | |||||
| PatternEngine pattern_engine(std::make_shared<DefaultVisitor>(), | |||||
| std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual), | |||||
| std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual)); | |||||
| auto empty_equiv = std::make_shared<Equiv>(); | |||||
| EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv); | |||||
| return equiv; | |||||
| } | |||||
| // make sure that only outputs 3, 4 and 5 of the while node are referenced | |||||
| bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(while_cnode != nullptr); | |||||
| auto manager = func_graph->manager(); | |||||
| if (manager == nullptr) { | |||||
| MS_LOG(ERROR) << "manager is nullptr"; | |||||
| return false; | |||||
| } | |||||
| auto while_node_users = manager->node_users()[while_cnode]; | |||||
| std::vector<size_t> valid_indexes{3, 4, 5}; | |||||
| for (auto &node_user : while_node_users) { | |||||
| if (!utils::isa<CNodePtr>(node_user.first)) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = utils::cast<CNodePtr>(node_user.first); | |||||
| if (GetCNodeType(cnode) != schema::PrimitiveType_TupleGetItem) { | |||||
| return false; | |||||
| } | |||||
| auto index = GetTupleGetItemOutIndex(cnode); | |||||
| if (!lite::IsContain(valid_indexes, index)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| EquivPtr TfliteLstmCellFusion::CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern, | |||||
| const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph, | |||||
| const size_t cnode_num, const size_t all_node_num) { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(pattern != nullptr); | |||||
| MS_ASSERT(anf_sub_graph != nullptr); | |||||
| auto sub_graph = GetValueNode<FuncGraphPtr>(anf_sub_graph); | |||||
| if (sub_graph == nullptr) { | |||||
| MS_LOG(DEBUG) << "anf_sub_graph is not a FuncGraph"; | |||||
| return nullptr; | |||||
| } | |||||
| auto nodes = TopoSort(sub_graph->get_return()); | |||||
| auto cnodes = sub_graph->GetOrderedCnodes(); | |||||
| if (cnodes.size() != cnode_num || nodes.size() != all_node_num) { | |||||
| MS_LOG(DEBUG) << "sub graph nodes num not match"; | |||||
| return nullptr; | |||||
| } | |||||
| return MatchGraph(sub_graph, primitive_vars, pattern); | |||||
| } | |||||
| bool TfliteLstmCellFusion::CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||||
| const CNodePtr &while_cnode, float *smooth) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(equiv != nullptr); | |||||
| MS_ASSERT(while_cnode != nullptr); | |||||
| MS_ASSERT(smooth != nullptr); | |||||
| auto cell_smooth_old_node = utils::cast<AnfNodePtr>((*equiv)[cell_smooth_old_]); | |||||
| MS_ASSERT(cell_smooth_old_node != nullptr); | |||||
| auto cell_smooth_new_node = utils::cast<AnfNodePtr>((*equiv)[cell_smooth_new_]); | |||||
| MS_ASSERT(cell_smooth_new_node != nullptr); | |||||
| auto hidden_smooth_old_node = utils::cast<AnfNodePtr>((*equiv)[hidden_smooth_old_]); | |||||
| MS_ASSERT(hidden_smooth_old_node != nullptr); | |||||
| auto hidden_smooth_new_node = utils::cast<AnfNodePtr>((*equiv)[hidden_smooth_new_]); | |||||
| MS_ASSERT(hidden_smooth_new_node != nullptr); | |||||
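| // the body computes state = new_coeff * state_new + old_coeff * state_old; fusion requires the | |||||
| // coefficients to sum to 1 and to agree between cell and hidden, collapsing them into one smooth value | |||||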
| float cell_old, cell_new, hidden_old, hidden_new; | |||||
| if (GetFloatScalarFromParamValueLite(cell_smooth_old_node, &cell_old) != RET_OK) { | |||||
| return false; | |||||
| } | |||||
| if (GetFloatScalarFromParamValueLite(cell_smooth_new_node, &cell_new) != RET_OK) { | |||||
| return false; | |||||
| } | |||||
| if (GetFloatScalarFromParamValueLite(hidden_smooth_old_node, &hidden_old) != RET_OK) { | |||||
| return false; | |||||
| } | |||||
| if (GetFloatScalarFromParamValueLite(hidden_smooth_new_node, &hidden_new) != RET_OK) { | |||||
| return false; | |||||
| } | |||||
| if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) { | |||||
| MS_LOG(DEBUG) << "cell smooth value illegal"; | |||||
| return false; | |||||
| } | |||||
| if (hidden_old < 0.0f || hidden_old > 1.0f || hidden_new < 0.0f || hidden_new > 1.0f) { | |||||
| MS_LOG(DEBUG) << "hidden smooth value illegal"; | |||||
| return false; | |||||
| } | |||||
| if (std::abs(cell_old + cell_new - 1.0f) > EPSILON || std::abs(hidden_old + hidden_new - 1.0f) > EPSILON || | |||||
| std::abs(cell_old - hidden_old) > EPSILON) { | |||||
| MS_LOG(DEBUG) << "smooth value illegal"; | |||||
| return false; | |||||
| } | |||||
| *smooth = cell_old; | |||||
| return true; | |||||
| } | |||||
| STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> ¶ms, const ParameterPtr &new_param, | |||||
| bool is_bias) const { | |||||
| MS_ASSERT(new_param != nullptr); | |||||
| MS_ASSERT(params.size() == 4); | |||||
| std::vector<float *> data_ptrs; | |||||
| std::vector<std::vector<int>> data_shapes; | |||||
| for (auto ¶m : params) { | |||||
| if (!utils::isa<ParameterPtr>(param)) { | |||||
| MS_LOG(DEBUG) << "param is not Parameter node"; | |||||
| return RET_FAILED; | |||||
| } | |||||
| auto param_t = utils::cast<ParameterPtr>(param); | |||||
| if (!param_t->has_default()) { | |||||
| MS_LOG(DEBUG) << "param not have default value"; | |||||
| return RET_FAILED; | |||||
| } | |||||
| if (!utils::isa<ParamValueLitePtr>(param_t->default_param())) { | |||||
| MS_LOG(DEBUG) << "default value is not ParamValueLite"; | |||||
| return RET_FAILED; | |||||
| } | |||||
| auto origin_tensor = std::dynamic_pointer_cast<ParamValueLite>(param_t->default_param()); | |||||
| if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { | |||||
| MS_LOG(DEBUG) << "origin_tensor is not float32 type"; | |||||
| return RET_FAILED; | |||||
| } | |||||
| auto data_ptr = reinterpret_cast<float *>(origin_tensor->tensor_addr()); | |||||
| auto data_shape = origin_tensor->tensor_shape(); | |||||
| data_ptrs.push_back(data_ptr); | |||||
| data_shapes.push_back(data_shape); | |||||
| } | |||||
| for (size_t i = 1; i < data_shapes.size(); ++i) { | |||||
| if (data_shapes[i] != data_shapes[0]) { | |||||
| MS_LOG(DEBUG) << "data shape not same"; | |||||
| return RET_FAILED; | |||||
| } | |||||
| } | |||||
| auto new_default = std::make_shared<ParamValueLite>(); | |||||
| if (new_default == nullptr) { | |||||
| MS_LOG(ERROR) << "new_default is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int> new_shape; | |||||
| float *tensor_data = nullptr; | |||||
| int step = 0; | |||||
| int data_size = 0; | |||||
| if (is_bias) { | |||||
| if (data_shapes[0].size() != 1) { | |||||
| MS_LOG(ERROR) << "bias data shape error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| step = data_shapes[0][0]; | |||||
| data_size = 8 * step; | |||||
| new_shape = std::vector<int>({1, data_size}); | |||||
| } else { | |||||
| if (data_shapes[0].size() != 2) { | |||||
| MS_LOG(ERROR) << "weight data shape error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| new_shape = std::vector<int>({1, data_shapes[0][0] * 4, data_shapes[0][1]}); | |||||
| step = data_shapes[0][0] * data_shapes[0][1]; | |||||
| data_size = 4 * step; | |||||
| } | |||||
| tensor_data = new (std::nothrow) float[data_size]; | |||||
| if (tensor_data == nullptr) { | |||||
| MS_LOG(ERROR) << "new data failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| for (int i = 0; i < data_size; ++i) {  // zero-init: for biases only the first 4 * step elements are overwritten below, the rest stays 0 | |||||
| tensor_data[i] = 0.0f; | |||||
| } | |||||
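| // copy each of the four gate tensors into its slice of the concatenated buffer | |||||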
| for (size_t i = 0; i < data_ptrs.size(); ++i) { | |||||
| auto source_len = std::accumulate(data_shapes[i].begin(), data_shapes[i].end(), 1, std::multiplies<int>()); | |||||
| auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s error"; | |||||
| delete[] tensor_data; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| new_default->set_tensor_shape(new_shape); | |||||
| new_default->set_tensor_type(kNumberTypeFloat32); | |||||
| new_default->set_format(schema::Format_NHWC); | |||||
| new_default->SetTensorData(tensor_data, data_size * sizeof(float)); | |||||
| new_param->set_default_param(new_default); | |||||
| std::vector<int64_t> shape_vector(new_shape.begin(), new_shape.end()); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| new_param->set_abstract(abstract_tensor); | |||||
| return RET_OK; | |||||
| } | |||||
| CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||||
| const EquivPtr &body_equiv, const std::string &base_name, | |||||
| const float smooth) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(equiv != nullptr); | |||||
| MS_ASSERT(body_equiv != nullptr); | |||||
| /* | |||||
| * input vars for while node | |||||
| * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_ | |||||
| * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_ | |||||
| */ | |||||
| auto lstm_primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>(); | |||||
| attr->bidirection = false; | |||||
| attr->smooth = smooth; | |||||
| lstm_primitive->value.type = schema::PrimitiveType_Lstm; | |||||
| lstm_primitive->value.value = attr.release(); | |||||
| auto lstm_cvalue = lite::PrimitiveC::Create(lstm_primitive.release()); | |||||
| auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(lstm_cvalue)); | |||||
| auto &vars = while_input_vars_; | |||||
| auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]); | |||||
| MS_ASSERT(limit1); | |||||
| auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]); | |||||
| MS_ASSERT(limit2); | |||||
| auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]); | |||||
| MS_ASSERT(i2i_weight); | |||||
| auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]); | |||||
| MS_ASSERT(i2f_weight); | |||||
| auto i2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[11]]); | |||||
| MS_ASSERT(i2c_weight); | |||||
| auto i2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[12]]); | |||||
| MS_ASSERT(i2o_weight); | |||||
| auto c2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[13]]); | |||||
| MS_ASSERT(c2i_weight); | |||||
| auto c2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[14]]); | |||||
| MS_ASSERT(c2f_weight); | |||||
| auto c2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[15]]); | |||||
| MS_ASSERT(c2c_weight); | |||||
| auto c2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[16]]); | |||||
| MS_ASSERT(c2o_weight); | |||||
| auto i_bias = utils::cast<AnfNodePtr>((*equiv)[vars[17]]); | |||||
| MS_ASSERT(i_bias); | |||||
| auto f_bias = utils::cast<AnfNodePtr>((*equiv)[vars[18]]); | |||||
| MS_ASSERT(f_bias); | |||||
| auto c_bias = utils::cast<AnfNodePtr>((*equiv)[vars[19]]); | |||||
| MS_ASSERT(c_bias); | |||||
| auto o_bias = utils::cast<AnfNodePtr>((*equiv)[vars[20]]); | |||||
| MS_ASSERT(o_bias); | |||||
| auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]); | |||||
| MS_ASSERT(input); | |||||
| auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]); | |||||
| MS_ASSERT(cell); | |||||
| auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]); | |||||
| MS_ASSERT(hidden); | |||||
| std::vector<AnfNodePtr> i_weights{i2i_weight, i2o_weight, i2f_weight, i2c_weight}; | |||||
| auto i_weight = func_graph->add_parameter(); | |||||
| auto status = GetConcatedParam(i_weights, i_weight, false); | |||||
| if (status != RET_OK) { | |||||
| return nullptr; | |||||
| } | |||||
| i_weight->set_name(base_name + "_weight_i"); | |||||
| std::vector<AnfNodePtr> c_weights{c2i_weight, c2o_weight, c2f_weight, c2c_weight}; | |||||
| auto c_weight = func_graph->add_parameter(); | |||||
| status = GetConcatedParam(c_weights, c_weight, false); | |||||
| if (status != RET_OK) { | |||||
| return nullptr; | |||||
| } | |||||
| c_weight->set_name(base_name + "_weight_c"); | |||||
| std::vector<AnfNodePtr> biases{i_bias, o_bias, f_bias, c_bias}; | |||||
| auto bias = func_graph->add_parameter(); | |||||
| status = GetConcatedParam(biases, bias, true); | |||||
| if (status != RET_OK) { | |||||
| return nullptr; | |||||
| } | |||||
| bias->set_name(base_name + "_bias"); | |||||
| if (!utils::isa<CNodePtr>(input) || GetCNodeType(input) != schema::PrimitiveType_TensorListFromTensor) { | |||||
| MS_LOG(DEBUG) << "input is not tensorlistfromtensor op"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_list_cnode = utils::cast<CNodePtr>(input); | |||||
| auto input_tensor_node = tensor_list_cnode->input(1); | |||||
| std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias, hidden, cell}; | |||||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||||
| new_node->set_fullname_with_scope(base_name); | |||||
| return new_node; | |||||
| } | |||||
| CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| const int item_index) { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(node != nullptr); | |||||
| auto tuple_get_item_prim_ptr = lite::GetTupleGetItemPrim(); | |||||
| if (tuple_get_item_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | |||||
| auto get_item_value = NewValueNode(MakeValue<int>(item_index)); | |||||
| if (tuple_get_item_prim == nullptr || get_item_value == nullptr) { | |||||
| MS_LOG(ERROR) << "NewValueNode is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, node, get_item_value}; | |||||
| CNodePtr get_item_cnode = func_graph->NewCNode(inputs); | |||||
| std::vector<int64_t> shape_vector; | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "create abstract_tensor failed"; | |||||
| return nullptr; | |||||
| } | |||||
| get_item_cnode->set_abstract(abstract_tensor); | |||||
| get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" + | |||||
| std::to_string(item_index)); | |||||
| return get_item_cnode; | |||||
| } | |||||
| STATUS TfliteLstmCellFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode, | |||||
| const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(while_cnode != nullptr); | |||||
| auto manager = func_graph->manager(); | |||||
| if (manager == nullptr) { | |||||
| MS_LOG(ERROR) << "manager is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto tr = manager->Transact(); | |||||
| auto while_node_users = manager->node_users()[while_cnode]; | |||||
| for (auto &node_user : while_node_users) { | |||||
| if (node_user.first == output_get_item) { | |||||
| continue; | |||||
| } | |||||
| if (!utils::isa<CNodePtr>(node_user.first)) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto get_item = utils::cast<CNodePtr>(node_user.first); | |||||
| if (GetCNodeType(get_item) != schema::PrimitiveType_TupleGetItem) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto new_inputs = get_item->inputs(); | |||||
| if (new_inputs.size() != 3) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| new_inputs[1] = lstm_cnode; | |||||
| auto index_vnode = get_item->input(2); | |||||
| if (!utils::isa<ValueNode>(index_vnode)) { | |||||
| MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto value_node = utils::cast<ValueNodePtr>(index_vnode); | |||||
| if (value_node == nullptr) { | |||||
| MS_LOG(ERROR) << "cast to ValueNode failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto origin_index = GetValue<int>(value_node->value()); | |||||
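| // remap while outputs to LSTM outputs: index 4 (cell state) becomes 2, index 5 (hidden state) becomes 1 | |||||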
| int new_index = origin_index == 4 ? 2 : 1; | |||||
| auto new_index_vnode = NewValueNode(MakeValue<int>(new_index)); | |||||
| new_inputs[2] = new_index_vnode; | |||||
| get_item->set_inputs(new_inputs); | |||||
| get_item->set_fullname_with_scope(lstm_cnode->fullname_with_scope() + "_getitem_" + std::to_string(new_index)); | |||||
| if (get_item->abstract() == nullptr) { | |||||
| MS_LOG(ERROR) << "get_item's abstract is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int> squeeze_axis{0}; | |||||
| auto squeeze_node = CreateSqueezeNode(func_graph, get_item, squeeze_axis); | |||||
| if (squeeze_node == nullptr) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto get_item_users = manager->node_users()[get_item]; | |||||
| for (auto &get_item_user : get_item_users) { | |||||
| tr.SetEdge(get_item_user.first, get_item_user.second, squeeze_node); | |||||
| } | |||||
| } | |||||
| tr.Commit(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) { | |||||
| MS_ASSERT(cnode != nullptr); | |||||
| AbstractBasePtrList abstract_list; | |||||
| for (int i = 0; i < output_num; ++i) { | |||||
| std::vector<int64_t> shape_vector; | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "create abstract_tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| abstract_list.emplace_back(abstract_tensor); | |||||
| } | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||||
| if (abstract_tuple == nullptr) { | |||||
| MS_LOG(ERROR) << "create abstract_tuple failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| cnode->set_abstract(abstract_tuple); | |||||
| return RET_OK; | |||||
| } | |||||
| CNodePtr TfliteLstmCellFusion::CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node, | |||||
| const std::vector<int> &axis) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new SqueezeT failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->axis = axis; | |||||
| auto new_primitive_t = std::make_unique<schema::PrimitiveT>(); | |||||
| if (new_primitive_t == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_t is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| new_primitive_t->value.type = schema::PrimitiveType_Squeeze; | |||||
| new_primitive_t->value.value = attr.release(); | |||||
| auto new_primtive_c = std::shared_ptr<lite::PrimitiveC>(lite::PrimitiveC::Create(new_primitive_t.release())); | |||||
| if (new_primtive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| ValueNodePtr value_node = NewValueNode(new_primtive_c); | |||||
| auto squeeze_cnode = func_graph->NewCNode({value_node, input_node}); | |||||
| squeeze_cnode->set_abstract(input_node->abstract()->Clone()); | |||||
| squeeze_cnode->set_fullname_with_scope("squeeze_" + input_node->fullname_with_scope()); | |||||
| return squeeze_cnode; | |||||
| } | |||||
| const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(node != nullptr); | |||||
| MS_LOG(DEBUG) << "lstm fusion pass"; | |||||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||||
| return nullptr; | |||||
| } | |||||
| if (!utils::isa<CNodePtr>(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_list_stack_cnode = utils::cast<CNodePtr>(node); | |||||
| auto tuple_get_item_node = tensor_list_stack_cnode->input(1); | |||||
| if (!utils::isa<CNodePtr>(tuple_get_item_node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item_node); | |||||
| auto while_node = tuple_get_item_cnode->input(1); | |||||
| if (!utils::isa<CNodePtr>(while_node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto while_cnode = utils::cast<CNodePtr>(while_node); | |||||
| if (CheckIfCNodeIsNull(while_cnode) != RET_OK || CheckInputSize(while_cnode, while_inputs_num_) != RET_OK) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!CheckReferencedOutputs(func_graph, while_cnode)) { | |||||
| return nullptr; | |||||
| } | |||||
| PrimitiveVarMapPtr primitive_vars_cond = std::make_shared<PrimitiveVarMap>(); | |||||
| auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond); | |||||
| auto cond_equiv = CheckSubGraph(func_graph, cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), | |||||
| cond_cnodes_num_, cond_nodes_num_); | |||||
| if (cond_equiv == nullptr || cond_equiv->empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| PrimitiveVarMapPtr primitive_vars_body = std::make_shared<PrimitiveVarMap>(); | |||||
| auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body); | |||||
| auto body_equiv = CheckSubGraph(func_graph, body_graph_pattern, primitive_vars_body, while_cnode->input(2), | |||||
| body_cnodes_num_, body_nodes_num_); | |||||
| if (body_equiv == nullptr || body_equiv->empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| float smooth = 0.0f; | |||||
| if (!CheckBodyGraph(func_graph, body_equiv, while_cnode, &smooth)) { | |||||
| return nullptr; | |||||
| } | |||||
| const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope(); | |||||
| auto lstm_node = CreateLSTMNode(func_graph, equiv, body_equiv, lstm_name, smooth); | |||||
| if (lstm_node == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto status = SetAbstractTuple(lstm_node, kLSTMOutputNum); | |||||
| if (status != RET_OK) { | |||||
| return nullptr; | |||||
| } | |||||
| auto get_item_node = CreateOutputGetItem(func_graph, lstm_node, 0); | |||||
| if (get_item_node == nullptr) { | |||||
| MS_LOG(DEBUG) << "create lstm output get_item node failed"; | |||||
| return nullptr; | |||||
| } | |||||
| status = AdjustOtherGetItems(func_graph, while_cnode, lstm_node, tuple_get_item_cnode); | |||||
| if (status != RET_OK) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<int> squeeze_axis{1};  // our LSTM output 0 has an extra axis that tflite's output lacks, so it must be squeezed | |||||
| auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis); | |||||
| if (squeeze_node == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1); | |||||
| func_graph->DropFuncGraphCNodeIndex(cond_cnode_index_pair); | |||||
| auto body_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 2); | |||||
| func_graph->DropFuncGraphCNodeIndex(body_cnode_index_pair); | |||||
| MS_LOG(INFO) << "lstm node:" << lstm_node->fullname_with_scope() << " fusion success"; | |||||
| return squeeze_node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,82 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "utils/utils.h" | |||||
| #include "include/errorcode.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class TfliteLstmCellFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit TfliteLstmCellFusion(const std::string &name = "tflite_lstm_cell_fusion", bool multigraph = true, | |||||
| int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0, | |||||
| int body_nodes_num = 0, int body_cnodes_num = 0); | |||||
| ~TfliteLstmCellFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| public: | |||||
| static EquivPtr MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars, | |||||
| const AnfNodePtr &pattern); | |||||
| static EquivPtr CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern, | |||||
| const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph, | |||||
| const size_t cnode_num, const size_t all_node_num); | |||||
| static lite::STATUS SetAbstractTuple(const CNodePtr &cnode, const int output_num); | |||||
| static CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index); | |||||
| protected: | |||||
| VarPtr cell_smooth_old_ = nullptr; | |||||
| VarPtr cell_smooth_new_ = nullptr; | |||||
| VarPtr hidden_smooth_old_ = nullptr; | |||||
| VarPtr hidden_smooth_new_ = nullptr; | |||||
| std::vector<VarPtr> while_input_vars_; | |||||
| lite::STATUS GetFloatScalarFromParamValueLite(const AnfNodePtr ¶m_value, float *v) const; | |||||
| CNodePtr CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node, | |||||
| const std::vector<int> &axis) const; | |||||
| lite::STATUS AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode, | |||||
| const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) const; | |||||
| AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; | |||||
| virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; | |||||
| virtual CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv, | |||||
| const std::string &base_name, const float smooth) const; | |||||
| private: | |||||
| bool CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const CNodePtr &while_cnode, | |||||
| float *smooth) const; | |||||
| bool CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) const; | |||||
| lite::STATUS GetConcatedParam(const std::vector<AnfNodePtr> ¶ms, const ParameterPtr &new_param, | |||||
| bool is_bias) const; | |||||
| private: | |||||
| size_t while_input_var_num_ = 0; | |||||
| size_t while_inputs_num_ = 0; | |||||
| size_t cond_nodes_num_ = 0; | |||||
| size_t cond_cnodes_num_ = 0; | |||||
| size_t body_nodes_num_ = 0; | |||||
| size_t body_cnodes_num_ = 0; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_ | |||||