From a771c10fb5f65e9464bf5869d687cce9a7a0b205 Mon Sep 17 00:00:00 2001 From: wangzhe Date: Wed, 30 Dec 2020 16:39:29 +0800 Subject: [PATCH] debug encoder --- mindspore/lite/nnacl/fp32/fill_fp32.c | 7 ++ mindspore/lite/nnacl/fp32/fill_fp32.h | 2 + mindspore/lite/src/ops/fill.cc | 75 ++---------- mindspore/lite/src/ops/transpose.cc | 9 ++ .../src/runtime/kernel/arm/fp32/fill_fp32.cc | 28 ++++- .../src/runtime/kernel/arm/fp32/fill_fp32.h | 2 + .../src/runtime/kernel/arm/fp32/gru_fp32.cc | 14 +-- .../runtime/kernel/arm/fp32/transpose_fp32.cc | 16 ++- .../lite/tools/converter/anf_transform.cc | 4 +- .../legacy_optimizer/graph/switch_pass.cc | 3 +- .../converter/parser/tf/tf_conv_parser.cc | 28 +++-- .../converter/parser/tf/tf_fill_parser.cc | 47 ++++---- .../converter/parser/tf/tf_fill_parser.h | 1 + .../converter/parser/tf/tf_node_parser.cc | 2 +- .../converter/parser/tf/tf_pool_parser.cc | 2 +- .../converter/parser/tf/tf_rsqrt_parser.cc | 61 ++++++++++ .../converter/parser/tf/tf_rsqrt_parser.h | 38 ++++++ .../parser/tf/tf_transpose_parser.cc | 43 ++++--- .../lite/tools/optimizer/common/gllo_utils.cc | 10 +- .../fusion/bidirection_tf_gru_cell_fusion.cc | 109 +++++++++--------- .../fusion/bidirection_tf_gru_cell_fusion.h | 7 +- .../graph/update_conv2d_param_pass.cc | 37 ++++++ .../graph/update_conv2d_param_pass.h | 6 + 23 files changed, 360 insertions(+), 191 deletions(-) create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_rsqrt_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_rsqrt_parser.h diff --git a/mindspore/lite/nnacl/fp32/fill_fp32.c b/mindspore/lite/nnacl/fp32/fill_fp32.c index 06a4423a6d..be915092bc 100644 --- a/mindspore/lite/nnacl/fp32/fill_fp32.c +++ b/mindspore/lite/nnacl/fp32/fill_fp32.c @@ -22,3 +22,10 @@ int Fill(float *output, int size, float data) { } return NNACL_OK; } + +int FillInt32(int *output, int size, int data) { + for (int i = 0; i < size; ++i) { + output[i] = data; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp32/fill_fp32.h b/mindspore/lite/nnacl/fp32/fill_fp32.h index 6f7b5d7f90..7678a61436 100644 --- a/mindspore/lite/nnacl/fp32/fill_fp32.h +++ b/mindspore/lite/nnacl/fp32/fill_fp32.h @@ -35,6 +35,8 @@ typedef struct FillParameter { extern "C" { #endif int Fill(float *output, int size, float data); + +int FillInt32(int *output, int size, int data); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc index 8f54d21f16..b322bc1ac6 100644 --- a/mindspore/lite/src/ops/fill.cc +++ b/mindspore/lite/src/ops/fill.cc @@ -56,26 +56,6 @@ PrimitiveC *FillCreator(const schema::Primitive *primitive) { return PrimitiveC: Registry FillRegistry(schema::PrimitiveType_Fill, FillCreator); #endif -template -void CalShape(const T *data, const std::vector &inputs, std::vector *out_shape, int shape_size) { - int input_count = inputs[0]->ElementsNum(); - int index = 0; - int size = 1; - for (int i = 0; i < shape_size; i++) { - if (static_cast(data[i]) == -1) { - index = i; - } else if (static_cast(data[i]) == 0) { - size *= inputs[0]->shape().at(i); - } else { - size *= data[i]; - } - out_shape->push_back(data[i]); - } - if (static_cast(data[index]) == -1) { - (*out_shape).at(index) = input_count / size; - } -} - int Fill::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); @@ -94,54 +74,23 @@ int Fill::InferShape(std::vector inputs_, std::vector output return RET_INFER_INVALID; } - std::vector out_shape; + std::vector output_shape; + auto param_dims = GetDims(); + for (size_t i = 0; i < param_dims.size(); i++) { + output_shape.push_back(param_dims.at(i)); + } + if (inputs_.size() == kDoubleNum) { - auto shape_tensor = inputs_.at(1); - if (shape_tensor->IsConst()) { - if (shape_tensor->data_c() == nullptr || (shape_tensor->shape().size() == 1 && shape_tensor->shape()[0] == 0)) { - MS_LOG(DEBUG) << "reshape to a scalar."; - output->set_shape(out_shape); - return RET_OK; - } - } - if (shape_tensor->data_c() == nullptr) { - MS_LOG(INFO) << "Do infer shape in runtime."; + auto input_dims = inputs_.at(1); + MS_ASSERT(input_dims != nullptr); + if (input_dims->data_c() == nullptr) { return RET_INFER_INVALID; } - size_t shape_size = shape_tensor->ElementsNum(); - switch (shape_tensor->data_type()) { - case kNumberTypeInt8: { - auto data = reinterpret_cast(shape_tensor->MutableData()); - CalShape(data, inputs_, &out_shape, shape_size); - } break; - case kNumberTypeInt32: { - auto data = reinterpret_cast(shape_tensor->MutableData()); - CalShape(data, inputs_, &out_shape, shape_size); - } break; - case kNumberTypeInt64: { - auto data = reinterpret_cast(shape_tensor->MutableData()); - CalShape(data, inputs_, &out_shape, shape_size); - } break; - case kNumberTypeFloat: { - auto data = reinterpret_cast(shape_tensor->MutableData()); - CalShape(data, inputs_, &out_shape, shape_size); - } break; - case kNumberTypeUInt32: { - auto data = reinterpret_cast(shape_tensor->MutableData()); - CalShape(data, inputs_, &out_shape, shape_size); - } break; - default: { - MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type(); - return RET_INFER_ERR; - } - } - } else { - for (size_t i = 0; i < GetDims().size(); i++) { - out_shape.push_back(GetDims().at(i)); - } + int *dims_data = reinterpret_cast(input_dims->data_c()); + output_shape = std::vector{dims_data, dims_data + input_dims->ElementsNum()}; } - output->set_shape(out_shape); + output->set_shape(output_shape); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc index 1b09454972..36eaefba4f 100644 --- a/mindspore/lite/src/ops/transpose.cc +++ b/mindspore/lite/src/ops/transpose.cc @@ -116,6 +116,15 @@ int Transpose::InferShape(std::vector inputs_, std::vector o MS_ASSERT(output != nullptr); std::vector perm = GetPerm(); + if (inputs_.size() == kDoubleNum) { + auto input_perm = inputs_.at(1); + MS_ASSERT(input_perm != nullptr); + if (input_perm->data_c() == nullptr) { + return RET_INFER_INVALID; + } + int *perm_data = reinterpret_cast(input_perm->data_c()); + perm = std::vector{perm_data, perm_data + input_perm->ElementsNum()}; + } std::vector nchw2nhwc_perm = {0, 2, 3, 1}; std::vector nhwc2nchw_perm = {0, 3, 1, 2}; std::vector in_shape = input->shape(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.cc index d21931d1c8..25c7a42eb8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.cc @@ -48,7 +48,15 @@ int FillCPUKernel::DoFill(int task_id) { return RET_OK; } int offset = task_id * thread_sz_stride_; - int ret = Fill(out_ptr_ + offset, size, src_data_); + auto input_tensor = in_tensors_.at(0); + int ret = RET_OK; + if (input_tensor->data_type() == kNumberTypeFloat32 || input_tensor->data_type() == kNumberTypeFloat) { + ret = Fill(out_ptr_ + offset, size, src_data_); + } else if (input_tensor->data_type() == kNumberTypeInt32 || input_tensor->data_type() == kNumberTypeInt) { + ret = FillInt32(int32_out_ptr_ + offset, size, int32_src_data_); + } else { + return RET_ERROR; + } if (ret != RET_OK) { MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]"; return ret; @@ -67,11 +75,20 @@ int FillRun(void *cdata, int task_id) { } int FillCPUKernel::Run() { - auto fillData = in_tensors_.at(in_tensors_.size() - 1); + auto fill_input = in_tensors_.front(); auto output = out_tensors_.front(); - auto fill_data = reinterpret_cast(fillData->MutableData()); - src_data_ = fill_data[0]; - out_ptr_ = reinterpret_cast(output->MutableData()); + if (fill_input->data_type() == kNumberTypeFloat32 || fill_input->data_type() == kNumberTypeFloat) { + auto fill_data = reinterpret_cast(fill_input->MutableData()); + src_data_ = fill_data[0]; + out_ptr_ = reinterpret_cast(output->MutableData()); + } else if (fill_input->data_type() == kNumberTypeInt32 || fill_input->data_type() == kNumberTypeInt) { + auto fill_data = reinterpret_cast(fill_input->MutableData()); + int32_src_data_ = fill_data[0]; + int32_out_ptr_ = reinterpret_cast(output->MutableData()); + } else { + MS_LOG(ERROR) << "unsupported fill data type " << fill_input->data_type(); + return RET_ERROR; + } auto ret = ParallelLaunch(this->context_->thread_pool_, FillRun, this, thread_sz_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]"; @@ -80,5 +97,6 @@ int FillCPUKernel::Run() { return RET_OK; } +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Fill, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Fill, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.h index 5927c854e6..c6990e267c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.h @@ -44,6 +44,8 @@ class FillCPUKernel : public LiteKernel { int data_size_; float src_data_; float *out_ptr_; + int int32_src_data_; + int *int32_out_ptr_; int thread_count_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc index e7a1911bc9..cd70788e54 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc @@ -138,10 +138,10 @@ int GruCPUKernel::Run() { MS_ASSERT(output != nullptr); auto input_ptr = reinterpret_cast(input->data_c()); MS_ASSERT(input_ptr); - auto output_ptr = reinterpret_cast(output->MutableData()); + auto output_ptr = reinterpret_cast(output->data_c()); MS_ASSERT(output_ptr); auto output_hidden_state = out_tensors_[1]; - memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float)); + memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float)); int check_seq_len = gru_parm_->seq_len_; if (in_tensors_.size() == 6) { auto seq_len = reinterpret_cast(in_tensors_.at(5)->data_c()); @@ -152,12 +152,12 @@ int GruCPUKernel::Run() { check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0])); } - MS_ASSERT(weight_g_ptr_); - MS_ASSERT(weight_r_ptr_); - MS_ASSERT(bias_ptr_); - MS_ASSERT(gate_buffer_); + MS_ASSERT(weight_g_ptr_ != nullptr); + MS_ASSERT(weight_r_ptr_ != nullptr); + MS_ASSERT(bias_ptr_ != nullptr); + MS_ASSERT(gate_buffer_ != nullptr); Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_, - reinterpret_cast(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_); + reinterpret_cast(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc index d8b4a3f63f..b1f63901c8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc @@ -39,7 +39,7 @@ int TransposeCPUKernel::Init() { int TransposeCPUKernel::ReSize() { TransposeParameter *param = reinterpret_cast(op_parameter_); - if (in_tensors_.at(kInputIndex)->shape().size() != static_cast(param->num_axes_)) { + if (in_tensors_.at(kInputIndex)->shape().size() != static_cast(param->num_axes_) && in_tensors_.size() != 2) { return RET_OK; } auto &inTensor = in_tensors_.front(); @@ -89,6 +89,20 @@ int TransposeCPUKernel::Run() { MS_ASSERT(out_data_); TransposeParameter *param = reinterpret_cast(this->op_parameter_); + if (in_tensors_.size() == 2) { + auto input_perm = in_tensors_.at(1); + MS_ASSERT(input_perm != nullptr); + MS_ASSERT(input_perm->data_c() != nullptr); + int *perm_data = reinterpret_cast(input_perm->data_c()); + auto perm = std::vector{perm_data, perm_data + input_perm->ElementsNum()}; + for (int i = 0; i < input_perm->ElementsNum(); ++i) { + param->perm_[i] = perm[i]; + } + for (int i = input_perm->ElementsNum(); i <= 8; ++i) { + param->perm_[i] = 0; + } + param->num_axes_ = input_perm->ElementsNum(); + } if (in_tensor->shape().size() != static_cast(param->num_axes_)) { memcpy(out_data_, in_data_, in_tensor->ElementsNum() * sizeof(float)); return RET_OK; diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 43fc4b48f0..538178cd78 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -161,7 +161,9 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap inne_context_ptr->Init(); const_fold_pm->AddPass(std::make_shared(inne_context_ptr)); } - const_fold_pm->AddPass(std::make_shared()); + auto update_conv2d_param_pass = std::make_shared(); + update_conv2d_param_pass->SetFmkType(config->fmk); + const_fold_pm->AddPass(update_conv2d_param_pass); fusion_pm->AddPass(std::make_shared()); convert_pm->AddPass(std::make_shared()); if (config->fmk == lite::converter::FmkType_TFLITE) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc index 06ec76e1e6..76bfe5a04a 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc @@ -280,6 +280,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() { second_partial_node_->outputIndex.push_back(graph_->allTensors.size() - 1); } + auto origin_switch_outputs = switch_node_->outputIndex; switch_node_->outputIndex.clear(); for (size_t i = 3; i < switch_node_->inputIndex.size(); i++) { auto &switch_in_tensor = graph_->allTensors.at(i); @@ -338,7 +339,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() { merge_node->inputIndex.insert(merge_node->inputIndex.end(), second_partial_node_->outputIndex.begin(), second_partial_node_->outputIndex.end()); } - merge_node->outputIndex = origin_switch_output_tensor_indices_; + merge_node->outputIndex = origin_switch_outputs; graph_->nodes.push_back(std::move(merge_node)); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc index 426f004c7b..33521fb416 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc @@ -67,19 +67,23 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, attr->strideW = strides[1]; auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1)); - if (weight_node == nullptr) { - MS_LOG(ERROR) << "Find Conv2D input weights failed"; - return RET_ERROR; - } - std::vector kernels(4); - status = ParseKernels(*weight_node, attr->format, &kernels); - if (status != RET_OK) { - return status; + if (weight_node != nullptr) { + std::vector kernels(4); + status = ParseKernels(*weight_node, attr->format, &kernels); + if (status != RET_OK) { + return status; + } + attr->kernelH = kernels[0]; + attr->kernelW = kernels[1]; + attr->channelIn = kernels[2]; + attr->channelOut = kernels[3]; + } else { + attr->kernelH = -1; + attr->kernelW = -1; + attr->channelIn = -1; + attr->channelOut = -1; + MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed"; } - attr->kernelH = kernels[0]; - attr->kernelW = kernels[1]; - attr->channelIn = kernels[2]; - attr->channelOut = kernels[3]; status = ParsePadMode(tf_op, &attr->padMode); if (status != RET_OK) { diff --git a/mindspore/lite/tools/converter/parser/tf/tf_fill_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_fill_parser.cc index b37ab69b1a..801a28e5a9 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_fill_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_fill_parser.cc @@ -42,20 +42,15 @@ STATUS TFFillParser::Parse(const tensorflow::NodeDef &tf_op, return RET_NULL_PTR; } - primitive->value.type = schema::PrimitiveType_Fill; - primitive->value.value = attr.release(); - *primitiveC = PrimitiveC::Create(primitive.release()); - if (*primitiveC == nullptr) { - MS_LOG(ERROR) << "primitiveC is nullptr"; - return RET_ERROR; - } - *output_size = 1; inputs->emplace_back(tf_op.input(1)); // parse dims tensorflow::AttrValue attr_value; auto dims_node = GetConstInputNode(tf_node_map, tf_op.input(0)); - MS_ASSERT(dims_node != nullptr); - if (dims_node != nullptr && TensorFlowUtils::FindAttrValue(*dims_node, "value", &attr_value)) { + if (dims_node != nullptr) { + if (!TensorFlowUtils::FindAttrValue(*dims_node, "value", &attr_value)) { + MS_LOG(ERROR) << "fill dims input not have value attr"; + return RET_ERROR; + } if (attr_value.value_case() != tensorflow::AttrValue::kTensor) { MS_LOG(ERROR) << "The attrValue of value should have tensor type, actual: " << attr_value.value_case() << ", node: " << tf_op.name().c_str(); @@ -66,32 +61,44 @@ STATUS TFFillParser::Parse(const tensorflow::NodeDef &tf_op, MS_LOG(ERROR) << "The dimsTensor dataType should be DT_INT32, actual : " << dims_tensor.dtype(); return RET_ERROR; } - const tensorflow::TensorShapeProto &dimsTensorShape = dims_tensor.tensor_shape(); - size_t shapeSize = 1; - for (int i = 0; i < dimsTensorShape.dim_size(); i++) { - shapeSize *= dimsTensorShape.dim(i).size(); + const tensorflow::TensorShapeProto &dims_tensor_shape = dims_tensor.tensor_shape(); + size_t shape_size = 1; + for (int i = 0; i < dims_tensor_shape.dim_size(); i++) { + shape_size *= dims_tensor_shape.dim(i).size(); } size_t size = dims_tensor.int_val().size(); if (size > 0) { - for (size_t i = 0; i < shapeSize; i++) { - attr->dims.emplace_back(dims_tensor.int_val().Get(0)); + for (size_t i = 0; i < shape_size; i++) { + attr->dims.emplace_back(dims_tensor.int_val().Get(i)); } } else { size = dims_tensor.tensor_content().length(); - if (size == shapeSize * sizeof(int32_t)) { - attr->dims.resize(shapeSize); + if (size > 0) { + if (size != shape_size * sizeof(int32_t)) { + MS_LOG(ERROR) << "tensor size mismatch"; + return RET_ERROR; + } + attr->dims.resize(shape_size); if (EOK != ::memcpy_s(attr->dims.data(), size, dims_tensor.tensor_content().data(), size)) { MS_LOG(ERROR) << "Memcpy_s from dimsTensor to attr failed"; return RET_ERROR; } } else { - MS_LOG(ERROR) << "Can not find weight data, node: " << dims_node->name().c_str(); - return RET_ERROR; + MS_LOG(DEBUG) << "empty dims"; } } } else { inputs->emplace_back(tf_op.input(0)); } + + primitive->value.type = schema::PrimitiveType_Fill; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + *output_size = 1; return RET_OK; } TFNodeRegistrar g_tfFillParser("Fill", new TFFillParser()); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_fill_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_fill_parser.h index 1793514261..f663cf3917 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_fill_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_fill_parser.h @@ -15,6 +15,7 @@ */ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ + #include #include #include diff --git a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc index 2d6f726c3c..017c5fae6a 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc @@ -46,7 +46,7 @@ const NodeDef *TFNodeParser::GetConstInputNode(const std::mapop() != "Const") { - MS_LOG(ERROR) << "Attr node is not Const"; + MS_LOG(DEBUG) << "Attr node is not Const"; return nullptr; } return node; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_pool_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_pool_parser.cc index d9cdbc08a0..6886d5afc9 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_pool_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_pool_parser.cc @@ -54,7 +54,7 @@ STATUS TFPoolParser::Parse(const tensorflow::NodeDef &tf_op, if (attr_value.s() == "VALID") { attr->padMode = schema::PadMode_VALID; } else if (attr_value.s() == "SAME") { - attr->padMode = schema::PadMode_VALID; + attr->padMode = schema::PadMode_SAME_UPPER; } } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_rsqrt_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_rsqrt_parser.cc new file mode 100644 index 0000000000..ef2251e2a4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_rsqrt_parser.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_rsqrt_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFRsqrtParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF RsqrtParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_Rsqrt; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + return RET_OK; +} +TFNodeRegistrar g_tfRsqrtParser("Rsqrt", new TFRsqrtParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_rsqrt_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_rsqrt_parser.h new file mode 100644 index 0000000000..dd40da5ce0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_rsqrt_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFRsqrtParser : public TFNodeParser { + public: + TFRsqrtParser() = default; + ~TFRsqrtParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc index b8d3f52d44..9aba20133d 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc @@ -41,28 +41,36 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op, MS_LOG(ERROR) << "new attr failed"; return RET_NULL_PTR; } - attr->conjugate = false; + + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + auto perm_node = GetConstInputNode(tf_node_map, tf_op.input(1)); if (perm_node == nullptr) { - MS_LOG(ERROR) << "Find Transpose input perm failed"; - return RET_ERROR; - } - tensorflow::AttrValue attr_value; - if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) { - MS_LOG(ERROR) << "The value attr should be specified"; - return RET_ERROR; - } - auto tensor_proto = attr_value.tensor(); - if (tensor_proto.int_val_size() > 0) { - for (int i = 0; i < tensor_proto.int_val_size(); ++i) { - attr->perm.push_back(tensor_proto.int_val(i)); + status = AddOpInput(tf_op, 1, inputs); + if (status != RET_OK) { + return status; } } else { - auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); - auto data = reinterpret_cast(tensor_proto.tensor_content().data()); - for (size_t i = 0; i < data_num; ++i) { - attr->perm.push_back(data[i]); + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->perm.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->perm.push_back(data[i]); + } } } @@ -75,7 +83,6 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op, } *output_size = 1; - auto status = AddOpInput(tf_op, 0, inputs); return status; } TFNodeRegistrar g_tfTransposeParser("Transpose", new TFTransposeParser()); diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 397220d584..6403dfa6cd 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -693,7 +693,7 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3 tensor->set_tensor_shape({filterC, filterK, filterH, filterW}); } else if (type == kKHWC2CHWK) { tensor->set_tensor_shape({filterC, filterH, filterW, filterK}); - } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) { + } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC || type == kHWCK2KHWC) { tensor->set_tensor_shape({filterK, filterH, filterW, filterC}); } else { MS_LOG(ERROR) << "Unsupported transFilterType: " << type; @@ -812,7 +812,8 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType } } break; case kHWCK2KCHW: - case kHWCK2CKHW: { + case kHWCK2CKHW: + case kHWCK2KHWC: { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { for (int c = 0; c < filterC; ++c) { @@ -821,9 +822,12 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType if (type == kHWCK2KCHW) { p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); - } else { + } else if (type == kHWCK2CKHW) { p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); } *p2Buff = *p1Buff; } diff --git a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc index 83e69b28d3..1cc32b483b 100644 --- a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc @@ -25,7 +25,6 @@ namespace mindspore { namespace opt { namespace { -constexpr size_t kWhileCommonInputsLength = 2; constexpr size_t kWhileUniqInputsLength = 6; constexpr size_t kCondNodesNum = 12; constexpr size_t kCondCNodesNum = 4; @@ -47,16 +46,11 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, : PatternProcessPass(name, multigraph) { /* * vars for while input - * common: - * 0:const0 1:init_state * fw_while_inputs: * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias * bw_while_inputs: * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias */ - for (size_t i = 0; i < kWhileCommonInputsLength; ++i) { - common_vars_.emplace_back(std::make_shared()); - } for (size_t i = 0; i < kWhileUniqInputsLength; ++i) { fw_vars_.emplace_back(std::make_shared()); bw_vars_.emplace_back(std::make_shared()); @@ -64,17 +58,16 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, input_ = std::make_shared(); input_length_ = std::make_shared(); transpose_input_ = std::make_shared(); + fw_init_state_ = std::make_shared(); + bw_init_state_ = std::make_shared(); } const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { - auto const1 = std::make_shared(IsParameterNode); - auto ele_shape = std::make_shared(IsParameterNode); - // forward auto fw_max1 = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_}); - auto fw_max2 = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, fw_max1}); + auto fw_max2 = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), + std::make_shared(IsParameterNode), fw_max1}); auto fw_shape = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_}); @@ -84,32 +77,33 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2}); auto fw_reserve = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape, - fw_stride}); + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), + std::make_shared(IsParameterNode), fw_stride}); auto fw_from_tensor = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), - transpose_input_, ele_shape}); + transpose_input_, std::make_shared(IsParameterNode)}); auto is_fw_while = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_While)); - auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], common_vars_[0], fw_stride, common_vars_[0], - fw_reserve, common_vars_[1], fw_min, fw_from_tensor, input_length_}); + auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], std::make_shared(IsParameterNode), + fw_stride, std::make_shared(IsParameterNode), fw_reserve, fw_init_state_, fw_min, + fw_from_tensor, input_length_}); fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end()); - fw_while.emplace_back(common_vars_[1]); + fw_while.emplace_back(std::make_shared()); auto fw_get_item = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)), fw_while, std::make_shared()}); auto fw_stack = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)), - fw_get_item, ele_shape}); - auto fw_out_trans = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), fw_stack}); + fw_get_item, std::make_shared(IsParameterNode)}); + auto fw_out_trans = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), + fw_stack, std::make_shared()}); // backward auto bw_reverse_seq = VectorRef( {std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), input_, input_length_}); auto bw_max1 = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_}); - auto bw_max2 = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, bw_max1}); - auto bw_trans = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_reverse_seq}); + auto bw_max2 = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), + std::make_shared(IsParameterNode), bw_max1}); + auto bw_trans = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), + bw_reverse_seq, std::make_shared()}); auto bw_shape = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans}); auto bw_stride = @@ -117,22 +111,23 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { auto bw_min = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2}); auto bw_reserve = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape, - bw_stride}); + VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), + std::make_shared(IsParameterNode), bw_stride}); auto bw_from_tensor = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), bw_trans, - ele_shape}); + std::make_shared(IsParameterNode)}); auto is_bw_while = std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_While)); - auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], common_vars_[0], bw_stride, common_vars_[0], - bw_reserve, common_vars_[1], bw_min, bw_from_tensor, input_length_}); + auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], std::make_shared(IsParameterNode), + bw_stride, std::make_shared(IsParameterNode), bw_reserve, bw_init_state_, bw_min, + bw_from_tensor, input_length_}); bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end()); - bw_while.emplace_back(common_vars_[1]); + bw_while.emplace_back(std::make_shared()); auto bw_get_item = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)), bw_while, std::make_shared()}); auto bw_stack = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)), - bw_get_item, ele_shape}); - auto bw_out_trans = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_stack}); + bw_get_item, std::make_shared(IsParameterNode)}); + auto bw_out_trans = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), + bw_stack, std::make_shared()}); auto bw_reverse1 = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), bw_out_trans, input_length_}); @@ -416,10 +411,12 @@ STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, } CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, - const AnfNodePtr &hidden_state, + const AnfNodePtr &fw_init_state, + const AnfNodePtr &bw_init_state, const std::string base_name) const { - MS_ASSERT(func_graph); - MS_ASSERT(hidden_state); + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(fw_init_state != nullptr); + MS_ASSERT(bw_init_state != nullptr); auto stack_primitive = std::make_unique(); std::unique_ptr attr = std::make_unique(); attr->axis = 0; @@ -427,9 +424,9 @@ CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &f stack_primitive->value.value = attr.release(); auto stack_cvalue = lite::PrimitiveC::Create(stack_primitive.release()); auto value_node = NewValueNode(std::shared_ptr(stack_cvalue)); - std::vector new_node_inputs = {value_node, hidden_state, hidden_state}; + std::vector new_node_inputs = {value_node, fw_init_state, bw_init_state}; auto new_node = func_graph->NewCNode(new_node_inputs); - new_node->set_abstract(hidden_state->abstract()->Clone()); + new_node->set_abstract(fw_init_state->abstract()->Clone()); new_node->set_fullname_with_scope("stack_hidden_" + base_name); return new_node; } @@ -452,31 +449,33 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr auto value_node = NewValueNode(std::shared_ptr(gru_cvalue)); auto fw_gate_kernel = utils::cast((*equiv)[fw_vars_[2]]); - MS_ASSERT(fw_gate_kernel); + MS_ASSERT(fw_gate_kernel != nullptr); auto fw_gate_bias = utils::cast((*equiv)[fw_vars_[3]]); - MS_ASSERT(fw_gate_bias); + MS_ASSERT(fw_gate_bias != nullptr); auto fw_cand_kernel = utils::cast((*equiv)[fw_vars_[4]]); - MS_ASSERT(fw_cand_kernel); + MS_ASSERT(fw_cand_kernel != nullptr); auto fw_cand_bias = utils::cast((*equiv)[fw_vars_[5]]); - MS_ASSERT(fw_cand_bias); + MS_ASSERT(fw_cand_bias != nullptr); auto bw_gate_kernel = utils::cast((*equiv)[bw_vars_[2]]); - MS_ASSERT(bw_gate_kernel); + MS_ASSERT(bw_gate_kernel != nullptr); auto bw_gate_bias = utils::cast((*equiv)[bw_vars_[3]]); - MS_ASSERT(bw_gate_bias); + MS_ASSERT(bw_gate_bias != nullptr); auto bw_cand_kernel = utils::cast((*equiv)[bw_vars_[4]]); - MS_ASSERT(bw_cand_kernel); + MS_ASSERT(bw_cand_kernel != nullptr); auto bw_cand_bias = utils::cast((*equiv)[bw_vars_[5]]); - MS_ASSERT(bw_cand_bias); + MS_ASSERT(bw_cand_bias != nullptr); - auto hidden = utils::cast((*equiv)[common_vars_[1]]); - MS_ASSERT(hidden); - auto stacked_hidden = GetStackedHiddenState(func_graph, hidden, base_name); + auto fw_init_state = utils::cast((*equiv)[fw_init_state_]); + MS_ASSERT(fw_init_state != nullptr); + auto bw_init_state = utils::cast((*equiv)[bw_init_state_]); + MS_ASSERT(bw_init_state != nullptr); + auto stacked_hidden = GetStackedHiddenState(func_graph, fw_init_state, bw_init_state, base_name); if (stacked_hidden == nullptr) { return nullptr; } auto input_length = utils::cast((*equiv)[input_length_]); - MS_ASSERT(hidden); + MS_ASSERT(hidden != nullptr); int input_size = 0; int hidden_size = 0; @@ -536,8 +535,8 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, const std::string base_name) const { - MS_ASSERT(func_graph); - MS_ASSERT(gru_output); + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(gru_output != nullptr); auto split_primitive = std::make_unique(); std::unique_ptr split_attr = std::make_unique(); split_attr->numberSplit = 2; @@ -603,8 +602,8 @@ CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node, const EquivPtr &equiv) const { - MS_ASSERT(func_graph); - MS_ASSERT(concat_node); + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(concat_node != nullptr); MS_LOG(DEBUG) << "bidirection tf gru fusion pass"; if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); @@ -612,7 +611,7 @@ const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_gr } auto transpose_input = utils::cast((*equiv)[transpose_input_]); - MS_ASSERT(transpose_input); + MS_ASSERT(transpose_input != nullptr); if (!utils::isa(transpose_input) || GetCNodeType(transpose_input) != schema::PrimitiveType_Transpose) { return nullptr; } diff --git a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h index be5ef15190..53819466bd 100644 --- a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h @@ -54,18 +54,19 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass { float *tensor_data) const; void CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, const int c0, const int c1, float *data, bool t = false) const; - CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &hidden_state, - const std::string base_name) const; + CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state, + const AnfNodePtr &bw_init_state, const std::string base_name) const; CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, const std::string base_name) const; private: - std::vector common_vars_; std::vector fw_vars_; std::vector bw_vars_; VarPtr input_; VarPtr input_length_; VarPtr transpose_input_; + VarPtr fw_init_state_; + VarPtr bw_init_state_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc index 95f247ea94..102428063b 100644 --- a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc @@ -53,7 +53,44 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) { primT->value.AsDepthwiseConv2D()->channelIn = weight->tensor_shape().at(0); } } + } else if (type == schema::PrimitiveType_Conv2D) { + auto conv2d_cnode = node->cast(); + auto primitive_c = GetValueNode>(conv2d_cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "Conv2D node has no primitiveC."; + continue; + } + auto primT = primitive_c->primitiveT(); + if (primT == nullptr) { + MS_LOG(ERROR) << "Conv2D node has no primitiveT."; + continue; + } + auto conv2d_primt = primT->value.AsConv2D(); + auto weight_node = conv2d_cnode->input(lite::kAnfPopulaterInputNumTwo); + if (weight_node == nullptr) { + MS_LOG(ERROR) << "Conv2D weight node is nullptr."; + continue; + } + if (!weight_node->isa()) { + MS_LOG(ERROR) << "Conv2D weight node is not parameter."; + continue; + } + auto weight_param = weight_node->cast(); + if (!weight_param->has_default()) { + MS_LOG(ERROR) << "Conv2D weight node is not parameter."; + continue; + } + auto default_param = weight_param->default_param(); + auto weight_tensor = std::dynamic_pointer_cast(default_param); + auto weight_shape = weight_tensor->tensor_shape(); + if (fmk_type == lite::converter::FmkType_TF && conv2d_primt->format == schema::Format_NHWC) { + conv2d_primt->kernelH = weight_shape[0]; + conv2d_primt->kernelW = weight_shape[1]; + conv2d_primt->channelIn = weight_shape[2]; + conv2d_primt->channelOut = weight_shape[3]; + } } + if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { MS_LOG(ERROR) << "remove identity pass is failed."; return false; diff --git a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h index c6958d8062..30894d9904 100644 --- a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h +++ b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h @@ -19,13 +19,19 @@ #include "schema/inner/model_generated.h" #include "backend/optimizer/common/pass.h" #include "tools/optimizer/common/gllo_utils.h" +#include "tools/converter/converter_flags.h" +using mindspore::lite::converter::FmkType; namespace mindspore::opt { class UpdateConv2DParamPass : public Pass { public: UpdateConv2DParamPass() : Pass("update_conv2d_param_pass") {} ~UpdateConv2DParamPass() override = default; bool Run(const FuncGraphPtr &graph) override; + void SetFmkType(FmkType fmk_type) { this->fmk_type = fmk_type; } + + private: + FmkType fmk_type = lite::converter::FmkType_ONNX; }; } // namespace mindspore::opt #endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_