@@ -191,12 +191,15 @@ if(ENABLE_CONVERTER)
 ${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc
 ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc
 ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc
+${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
 ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
 ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
 ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
 ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
 ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
 ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
+${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
+${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
 )
 endif()
 ### train
@@ -39,6 +39,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
 ../optimizer/fusion/conv_transform_fusion.cc
 ../optimizer/fusion/conv_scale_fusion.cc
 ../optimizer/fusion/conv_bn_fusion.cc
+../optimizer/fusion/conv_tuplegetitem_fusion.cc
 ../optimizer/fusion/constant_folding_fusion.cc
 ../optimizer/fusion/quant_dtype_cast_fusion.cc
 ../optimizer/fusion/layer_norm_fusion.cc
@@ -51,6 +52,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
 ../optimizer/graph/unused_cast_node_remove_pass.cc
 ../optimizer/graph/unused_transpose_node_remove_pass.cc
 ../optimizer/graph/identity_remove_pass.cc
+../optimizer/graph/infershape_pass.cc
+../optimizer/graph/slice_prepose_pass.cc
 )
 add_subdirectory(../anf_importer anf_importer)
@@ -23,6 +23,7 @@
 #include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
 #include "tools/optimizer/fusion/conv_scale_fusion.h"
 #include "tools/optimizer/fusion/conv_bn_fusion.h"
+#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
 #include "tools/optimizer/fusion/constant_folding_fusion.h"
 #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
 #include "tools/optimizer/fusion/layer_norm_fusion.h"
@@ -35,6 +36,8 @@
 #include "tools/optimizer/graph/clip_convert_activation_pass.h"
 #include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
 #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
+#include "tools/optimizer/graph/infershape_pass.h"
+#include "tools/optimizer/graph/slice_prepose_pass.h"
 #include "tools/converter/quantizer/post_training_quantizer.h"
 #include "tools/converter/quantizer/quant_cast.h"
 #include "tools/converter/quantizer/weight_quantizer.h"
@@ -76,6 +79,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
                                                           schema::ActivationType_RELU));
   pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
                                                           schema::ActivationType_RELU6));
+  pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
   pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
     true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
   pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
@@ -89,6 +93,12 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
   weight_format_transform_pass->SetFmkType(config->fmk);
   weight_format_transform_pass->SetQuantType(config->quantType);
   graph_pm->AddPass(weight_format_transform_pass);
+  auto infershape_pass = std::make_shared<opt::InferShapePass>();
+  infershape_pass->SetFmkType(config->fmk);
+  graph_pm->AddPass(infershape_pass);
+  auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
+  slice_prepose_pass->SetFmkType(config->fmk);
+  graph_pm->AddPass(slice_prepose_pass);
   if (config->fmk == lite::converter::FmkType_MS) {
     auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
@@ -406,6 +406,55 @@ ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) {
   auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param->default_param());
   return param_value;
 }
+AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) {
+  if (cnode == nullptr) {
+    MS_LOG(ERROR) << "CNodePtr is nullptr";
+    return nullptr;
+  }
+  auto inputs = cnode->inputs();
+  if (!(0 < index && index < inputs.size())) {
+    return nullptr;
+  }
+  auto input = inputs[index];
+  if (input == nullptr) {
+    MS_LOG(ERROR) << "CNode input is nullptr";
+    return nullptr;
+  }
+  AbstractBasePtr abstract = nullptr;
+  if (utils::isa<ParameterPtr>(input)) {
+    auto parameter = input->cast<ParameterPtr>();
+    abstract = parameter->abstract();
+  } else if (utils::isa<CNodePtr>(input)) {
+    auto input_cnode = input->cast<CNodePtr>();
+    if (GetCNodeType(input_cnode) == schema::PrimitiveType_TupleGetItem) {
+      auto tuple_inputs = input_cnode->inputs();
+      MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
+      auto get_item_input_cnode = tuple_inputs.at(1);
+      MS_ASSERT(get_item_input_cnode != nullptr);
+      auto idx = GetTupleGetItemOutIndex(input_cnode);
+      if (!utils::isa<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
+        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
+        return nullptr;
+      }
+      auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
+      auto abstract_list = abstract_tuple->elements();
+      if (abstract_list.size() <= idx) {
| MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect"; | |||||
+        return nullptr;
+      }
+      abstract = abstract_list[idx];
+    } else {
+      abstract = input_cnode->abstract();
+    }
+  } else {
+    MS_LOG(ERROR) << "unsupported input node type";
+    return nullptr;
+  }
+  return abstract;
+}
 bool IsParamNode(const BaseRef &n) {
   if (!utils::isa<ParameterPtr>(n)) {
     return false;
@@ -75,6 +75,8 @@ size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
 ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node);
+AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index);
 enum kTransFilterType {
   kKCHW2HWCK,  // 0
   kKCHW2KHWC,
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
+#include <memory>
+#include "src/ops/primitive_c.h"
+#include "src/param_value_lite.h"
+#include "schema/inner/model_generated.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "securec/include/securec.h"
+namespace mindspore::opt {
+namespace {
+constexpr size_t kTupleGetItemLen = 3;
+bool IsTupleGetItemNode(const BaseRef &n) {
+  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
+    auto type = opt::GetCNodeType(n);
+    return type == schema::PrimitiveType_TupleGetItem;
+  }
+  return false;
+}
+}  // namespace
+const BaseRef ConvTupleGetItemFusion::DefinePattern() const {
+  auto tuple_var = std::make_shared<CondVar>(IsTupleGetItemNode);
+  auto tuple_index = std::make_shared<Var>();
+  auto conv_var = std::make_shared<CondVar>(IsConvNode);
+  return VectorRef({tuple_var, conv_var, tuple_index});
+}
+const AnfNodePtr ConvTupleGetItemFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                 const EquivPtr &equiv) const {
+  MS_LOG(DEBUG) << "conv_tuplegetitem_fusion pass";
+  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return nullptr;
+  }
+  auto tuple_cnode = node->cast<CNodePtr>();
+  if (CheckIfCNodeIsNull(tuple_cnode) != lite::RET_OK ||
+      CheckInputSize(tuple_cnode, kTupleGetItemLen) != lite::RET_OK) {
+    return nullptr;
+  }
+  auto idx = GetTupleGetItemOutIndex(tuple_cnode);
+  if (idx != 0) {
+    MS_LOG(DEBUG) << "TupleGetItem's idx is not 0";
+    return nullptr;
+  }
+  auto conv_node = tuple_cnode->input(1);
+  if (CheckIfAnfNodeIsNull(conv_node) != lite::RET_OK) {
+    return nullptr;
+  }
+  auto conv_cnode = conv_node->cast<CNodePtr>();
+  if (CheckIfCNodeIsNull(conv_cnode) != lite::RET_OK) {
+    return nullptr;
+  }
+  auto abstr = conv_cnode->abstract();
+  if (utils::isa<abstract::AbstractTuplePtr>(abstr)) {
+    auto elements = utils::cast<abstract::AbstractTuplePtr>(abstr)->elements();
+    if (elements.empty()) {
+      MS_LOG(ERROR) << "AbstractTuple is empty";
+      return nullptr;
+    }
+    conv_node->set_abstract(elements[0]);
+  }
+  return conv_node;
+}
+}  // namespace mindspore::opt
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LITE_MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
+#define LITE_MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
+#include <string>
+#include "backend/optimizer/common/optimizer.h"
+namespace mindspore::opt {
+class ConvTupleGetItemFusion : public PatternProcessPass {
+ public:
+  explicit ConvTupleGetItemFusion(const std::string &name = "conv_tuplegetitem_fusion", bool multigraph = true)
+      : PatternProcessPass(name, multigraph) {}
+  ~ConvTupleGetItemFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace mindspore::opt
+#endif  // LITE_MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
@@ -324,6 +324,7 @@ const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const
   }
   auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, gamma_shape, epsilon);
+  layer_norm_cnode->set_abstract(add2_cnode->abstract());
   layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
   MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success";
   return layer_norm_cnode;
@@ -0,0 +1,306 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/graph/infershape_pass.h"
+#include <vector>
+#include <memory>
+#include <algorithm>
+#include "mindspore/lite/include/errorcode.h"
+#include "mindspore/lite/src/ops/primitive_c.h"
+#include "tools/anf_importer/import_from_meta_graphT.h"
+namespace mindspore::opt {
+abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) {
+  MS_ASSERT(nullptr != tensor);
+  std::vector<int> shape(tensor->shape());
+  auto type_id = static_cast<TypeId>(tensor->data_type());
+  auto type_ptr = TypeIdToType(type_id);
+  std::vector<int64_t> shape_vector;
+  (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
+                       [](const int32_t &value) { return static_cast<int64_t>(value); });
+  auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
+  if (new_abstract == nullptr) {
+    MS_LOG(ERROR) << "new AbstractTensor failed";
+    return nullptr;
+  }
+  auto new_value = std::make_shared<ParamValueLite>();
+  if (new_value == nullptr) {
+    MS_LOG(ERROR) << "new ParamValueLite failed";
+    return nullptr;
+  }
+  new_value->set_tensor_shape(tensor->shape());
+  new_value->set_tensor_type(tensor->data_type());
+  new_value->set_format(tensor->GetFormat());
+  new_abstract->set_value(new_value);
+  return new_abstract;
+}
+STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
+  MS_ASSERT(parameter != nullptr);
+  auto old_abstract = parameter->abstract();
+  if (old_abstract == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name();
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(old_abstract)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << parameter->name();
+    return RET_ERROR;
+  }
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(old_abstract);
+  auto typePtr = abstract_tensor->element()->GetTypeTrack();
+  if (typePtr == nullptr) {
+    MS_LOG(ERROR) << "typePtr is nullptr";
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
+    MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << parameter->name();
+    return RET_ERROR;
+  }
+  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
+  std::vector<int32_t> shape;
+  (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
+                       [](const int64_t &value) { return static_cast<int32_t>(value); });
+  auto new_abstract = std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector);
+  auto new_value = std::make_shared<ParamValueLite>();
+  new_value->set_tensor_shape(shape);  // scalar's shape is {}
+  new_value->set_tensor_type(typePtr->type_id());
+  new_value->set_format(schema::Format_NHWC);  // default format is NHWC
+  if (parameter->has_default()) {
+    auto param_value = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
+    new_value->set_format(param_value->format());
+    new_value->set_tensor_size(param_value->tensor_size());
+    char *tensor_data = new (std::nothrow) char[new_value->tensor_size()];
+    if (tensor_data == nullptr) {
+      MS_LOG(ERROR) << "new char[] failed";
+      return RET_ERROR;
+    }
+    auto ret = memcpy_s(tensor_data, new_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size());
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "memcpy error: " << ret;
+      return RET_ERROR;
+    }
+    new_value->set_tensor_addr(tensor_data);
+  }
+  new_abstract->set_value(new_value);
+  parameter->set_abstract(new_abstract);
+  return RET_OK;
+}
+void InferShapePass::FreeTensors(std::vector<lite::Tensor *> *tensors) {
+  for (auto tensor : *tensors) {
+    delete tensor;
+  }
+  tensors->clear();
+  tensors->shrink_to_fit();
+}
+STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors) {
+  MS_ASSERT(cnode != nullptr);
+  MS_ASSERT(input_tensors != nullptr);
+  auto inputs = cnode->inputs();
+  for (size_t i = 1; i < inputs.size(); ++i) {
+    auto input = inputs[i];
+    if (input == nullptr) {
+      MS_LOG(ERROR) << "input is nullptr";
+      return RET_ERROR;
+    }
+    auto tensor = std::make_unique<lite::Tensor>();
+    if (tensor == nullptr) {
+      MS_LOG(ERROR) << "new input tensor failed";
+      return RET_ERROR;
+    }
+    if (utils::isa<ValueNodePtr>(cnode->input(i))) {
+      MS_LOG(ERROR) << "input is value node";
+      continue;
+    }
+    AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i);
+    if (abstract == nullptr) {
+      MS_LOG(ERROR) << "Abstract of CNode is nullptr";
+      return RET_ERROR;
+    }
+    if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
+      MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor";
+      return RET_ERROR;
+    }
+    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
+    if (!utils::isa<ParamValueLitePtr>(abstract_tensor->GetValueTrack())) {  // input node has not completed infershape
+      MS_LOG(DEBUG) << "Value of abstract is not ParamValueLite, indicating that infershape has failed";
+      return RET_ERROR;
+    }
+    auto param_value_lite = utils::cast<ParamValueLitePtr>(abstract_tensor->GetValueTrack());
+    if (param_value_lite == nullptr) {
+      MS_LOG(ERROR) << "ParamValueLite of abstract is nullptr";
+      return RET_ERROR;
+    }
+    tensor->set_shape(param_value_lite->tensor_shape());
+    tensor->set_data_type(param_value_lite->tensor_type());
+    tensor->SetFormat(schema::Format(param_value_lite->format()));
+    if (utils::isa<ParameterPtr>(input)) {
+      auto parameter = input->cast<ParameterPtr>();
+      if (parameter->has_default()) {
+        auto param_value = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
+        auto ret = tensor->MallocData();
+        if (ret != 0) {
+          MS_LOG(ERROR) << "Malloc tensor data failed";
+          return RET_ERROR;
+        }
+        ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size());
+        if (ret != EOK) {
+          MS_LOG(ERROR) << "memcpy error: " << ret;
+          return RET_ERROR;
+        }
+      }
+    }
+    input_tensors->push_back(tensor.release());
+  }
+  return RET_OK;
+}
+STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors) {
+  MS_ASSERT(output_tensors != nullptr);
+  auto abstract = cnode->abstract();
+  if (abstract == nullptr) {
+    MS_LOG(ERROR) << "abstract is nullptr";
+    return RET_ERROR;
+  }
+  size_t num_outputs = 1;
+  if (utils::isa<abstract::AbstractTuple>(abstract)) {
+    auto abstract_tuple = abstract->cast<abstract::AbstractTuplePtr>();
+    num_outputs = abstract_tuple->size();
+  }
+  for (size_t i = 0; i < num_outputs; ++i) {
+    auto output_tensor = std::make_unique<lite::Tensor>();
+    if (output_tensor == nullptr) {
+      MS_LOG(ERROR) << "new output tensor failed";
+      return RET_ERROR;
+    }
+    output_tensors->push_back(output_tensor.release());
+  }
+  return RET_OK;
+}
+STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors,
+                                        const std::shared_ptr<CNode> &cnode) {
+  MS_ASSERT(cnode != nullptr);
+  if (output_tensors.size() == 0) {
+    MS_LOG(ERROR) << "empty output_tensors";
+    return RET_ERROR;
+  }
+  if (output_tensors.size() == 1) {
+    auto tensor = output_tensors.front();
+    auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor);
+    if (new_abstract == nullptr) {
+      return RET_ERROR;
+    }
+    cnode->set_abstract(new_abstract);
+  } else {
+    AbstractBasePtrList abstract_list;
+    for (size_t i = 0; i < output_tensors.size(); i++) {
+      auto tensor = output_tensors[i];
+      auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor);
+      if (new_abstract == nullptr) {
+        return RET_ERROR;
+      }
+      abstract_list.emplace_back(new_abstract);
+    }
+    cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
+  }
+  return RET_OK;
+}
+bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
+  if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) {
+    MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
+    return false;
+  }
+  MS_ASSERT(func_graph != nullptr);
+  auto manager = func_graph->manager();
+  MS_ASSERT(manager != nullptr);
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (utils::isa<ParameterPtr>(node)) {
+      int status = SetParameterAbstract(node->cast<ParameterPtr>());
+      if (status != RET_OK) {
+        return false;
+      }
+      continue;
+    }
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    auto origin_primc = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(cnode->input(0));
+    if (origin_primc == nullptr) {
+      MS_LOG(ERROR) << "origin_primc is nullptr";
+      return false;
+    }
+    auto origin_primt = origin_primc->GetPrimitiveT();
+    if (origin_primt == nullptr) {
+      MS_LOG(ERROR) << "origin_primt is nullptr";
+      return false;
+    }
+    auto type = GetCNodeType(cnode);
+    if ((type == schema::PrimitiveType_TupleGetItem) ||
+#ifdef SUPPORT_TRAIN
+        (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) ||
+#endif
+        (type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) {
+      continue;
+    }
+    std::vector<lite::Tensor *> input_tensors;
+    std::vector<lite::Tensor *> output_tensors;
+    auto status = GetCNodeInputTensors(cnode, &input_tensors);
+    if (status != RET_OK) {
+      MS_LOG(DEBUG) << "input shape unknown, infershape can't process cnode " << cnode->fullname_with_scope();
+      FreeTensors(&input_tensors);
+      continue;
+    }
+    status = GetCNodeOutputTensors(cnode, &output_tensors);
+    if (status != RET_OK) {
+      FreeTensors(&input_tensors);
+      FreeTensors(&output_tensors);
+      continue;
+    }
+    auto primt = std::make_unique<schema::PrimitiveT>();
+    if (primt == nullptr) {
+      MS_LOG(ERROR) << "primt is nullptr";
+      return false;
+    }
+    *primt = *origin_primt;
+    auto primc = std::shared_ptr<lite::PrimitiveC>(lite::PrimitiveC::Create(primt.release()));
+    if (primc == nullptr) {
+      MS_LOG(ERROR) << "primc is nullptr";
+      return false;
+    }
+    status = primc->InferShape(input_tensors, output_tensors);
+    if (status == RET_OK) {
+      status = SetCNodeAbstract(output_tensors, cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
+      }
+    }
+    FreeTensors(&input_tensors);
+    FreeTensors(&output_tensors);
+  }
+  return true;
+}
+}  // namespace mindspore::opt
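A note for readers of the two graph passes in this patch: InferShapePass records each inferred shape in a ParamValueLite attached to the abstract's value track, and later passes test for exactly that before trusting shape data. The sketch below distills that contract into a standalone helper (illustrative only; HasInferredShape is not part of the patch, and it assumes the MindSpore Lite headers this file already pulls in):

// Hypothetical helper: true once InferShapePass (or SetParameterAbstract) has attached
// a ParamValueLite carrying the inferred shape to the node's abstract.
bool HasInferredShape(const AnfNodePtr &node) {
  if (node == nullptr || node->abstract() == nullptr) {
    return false;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(node->abstract())) {
    return false;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(node->abstract());
  return utils::isa<ParamValueLitePtr>(abstract_tensor->GetValueTrack());
}

SlicePreposePass below performs the same test (via its GetCNodeInputShape helper) before deciding whether a swap is safe.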
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INFERSHAPE_PASS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INFERSHAPE_PASS_H_
+#include <vector>
+#include <memory>
+#include <string>
+#include "tools/converter/converter_flags.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "backend/optimizer/common/pass.h"
+#include "mindspore/lite/src/tensor.h"
+#include "mindspore/lite/include/errorcode.h"
+using mindspore::lite::STATUS;
+using mindspore::lite::converter::FmkType;
+namespace mindspore::opt {
+class InferShapePass : public Pass {
+ public:
+  InferShapePass() : Pass("infershape_pass") {}
+  ~InferShapePass() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+  void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; }
+ private:
+  void FreeTensors(std::vector<lite::Tensor *> *tensors);
+  abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor);
+  STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors);
+  STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors);
+  STATUS SetParameterAbstract(const ParameterPtr &parameter);
+  STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode);
+ private:
+  FmkType fmk_type = lite::converter::FmkType_ONNX;
+};
+}  // namespace mindspore::opt
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INFERSHAPE_PASS_H_
@@ -0,0 +1,247 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/graph/slice_prepose_pass.h"
+#include <vector>
+#include <memory>
+#include "mindspore/lite/include/errorcode.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "src/ops/primitive_c.h"
+#include "schema/inner/model_generated.h"
+#include "src/common/log_adapter.h"
+using mindspore::lite::PrimitiveC;
+namespace mindspore::opt {
+namespace {
+std::vector<int32_t> GetCNodeInputShape(const CNodePtr &cnode, size_t index = 1) {
+  MS_ASSERT(cnode != nullptr);
+  std::vector<int32_t> empty_shape;
+  if (index < 1 || cnode->inputs().size() <= index) {
+    MS_LOG(ERROR) << "out of index";
+    return empty_shape;
+  }
+  auto abstract = GetCNodeInputAbstract(cnode, index);
+  if (abstract == nullptr) {
+    MS_LOG(ERROR) << "Abstract of CNode is nullptr";
+    return empty_shape;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
+    MS_LOG(DEBUG) << "abstract is not AbstractTensor";
+    return empty_shape;
+  }
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
+  if (!utils::isa<ParamValueLitePtr>(abstract_tensor->GetValueTrack())) {
| MS_LOG(DEBUG) << "Value of abstract is not ParamValueLite, indicate that infershape has failed"; | |||||
+    return empty_shape;
+  }
+  auto param_value_lite = utils::cast<ParamValueLitePtr>(abstract_tensor->GetValueTrack());
+  if (param_value_lite == nullptr) {
+    MS_LOG(ERROR) << "ParamValueLite of abstract is nullptr";
+    return empty_shape;
+  }
+  return param_value_lite->tensor_shape();
+}
+}  // namespace
+schema::SliceT *SlicePreposePass::GetSliceT(const CNodePtr &cnode) {
+  if (cnode == nullptr) {
+    return nullptr;
+  }
+  auto primc = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+  if (primc == nullptr) {
+    return nullptr;
+  }
+  auto primt = primc->GetPrimitiveT();
+  if (primt == nullptr || primt->value.AsSlice() == nullptr) {
+    return nullptr;
+  }
+  return primt->value.AsSlice();
+}
+STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
+                                              const CNodePtr &preceed_cnode, const int index,
+                                              const TransactionPtr &tr) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(slice_cnode != nullptr);
+  MS_ASSERT(preceed_cnode != nullptr);
+  if (slice_cnode->input(1) != preceed_cnode) {
+    MS_LOG(ERROR) << "preceed node must be slice node's direct parent";
+    return RET_ERROR;
+  }
+  if (IsMultiOutputTensors(graph, preceed_cnode)) {
| MS_LOG(ERROR) << "preceed node referenced by multi nodes not support swap"; | |||||
+    return RET_ERROR;
+  }
+  auto manager = graph->manager();
+  if (manager == nullptr) {
+    MS_LOG(ERROR) << "manager is nullptr";
+    return RET_ERROR;
+  }
+  auto node_users = manager->node_users()[slice_cnode];
+  if (tr != nullptr) {  // do swap with transaction
+    for (auto &node_user : node_users) {
+      tr->SetEdge(node_user.first, node_user.second, preceed_cnode);
+    }
+    tr->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
+    tr->SetEdge(preceed_cnode, index, slice_cnode);
+  } else {
+    for (auto &node_user : node_users) {
+      manager->SetEdge(node_user.first, node_user.second, preceed_cnode);
+    }
+    manager->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
+    manager->SetEdge(preceed_cnode, index, slice_cnode);
+  }
+  return RET_OK;
+}
+/*
+ * Prepose condition:
+ * the softmax axis is not sliced
+ */
+bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
+                                          const CNodePtr &softmax_cnode) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(slice_cnode != nullptr);
+  MS_ASSERT(softmax_cnode != nullptr);
+  auto softmax_primc = GetValueNode<std::shared_ptr<PrimitiveC>>(softmax_cnode->input(0));
+  if (softmax_primc == nullptr) {
+    MS_LOG(ERROR) << "softmax_primc is nullptr";
+    return false;
+  }
+  auto softmax_primt = softmax_primc->GetPrimitiveT();
+  if (softmax_primt == nullptr || softmax_primt->value.AsSoftMax() == nullptr) {
+    MS_LOG(ERROR) << "softmax_primt is nullptr";
+    return false;
+  }
+  auto softmax_attr = softmax_primt->value.AsSoftMax();
+  auto softmax_axis = softmax_attr->axis;
+  auto shape = GetCNodeInputShape(softmax_cnode, 1);
+  if (softmax_axis == -1) {
+    if (shape.empty()) {  // when softmax axis == -1, shape info is needed to determine whether slice can be preposed
+      return false;
+    }
+    softmax_axis += shape.size();
+  }
+  auto slice_t = GetSliceT(slice_cnode);
+  MS_ASSERT(slice_t != nullptr);
+  auto slice_axes = slice_t->axes;
+  auto slice_begin = slice_t->begin;
+  auto slice_size = slice_t->size;
+  for (size_t i = 0; i < slice_axes.size(); ++i) {
+    if (slice_axes[i] == softmax_axis) {
+      if (slice_begin[i] != 0) {
+        return false;
+      }
+      if (slice_size[i] != -1) {
+        if (shape.empty() || slice_axes[i] >= static_cast<int>(shape.size())) {
+          return false;
+        }
+        if (slice_size[i] < shape[slice_axes[i]]) {
+          return false;
+        }
+      }
+    }
+  }
+  auto status = SwapSliceWithPreceed(graph, slice_cnode, softmax_cnode, 1);
+  return status == RET_OK;
+}
+bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
+                                 const CNodePtr &preceed_cnode) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(slice_cnode != nullptr);
+  MS_ASSERT(preceed_cnode != nullptr);
+  auto preceed_node_type = GetCNodeType(preceed_cnode);
+  switch (preceed_node_type) {
+    case schema::PrimitiveType_SoftMax: {
+      return PreposeWithSoftmax(graph, slice_cnode, preceed_cnode);
+    }
+    default: {
| MS_LOG(DEBUG) << "Node type " << preceed_node_type << " currently not support SlicePrepose"; | |||||
+    }
+  }
+  return false;
+}
+bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
+  if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) {
+    MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
+    return false;
+  }
+  MS_ASSERT(graph != nullptr);
+  bool changed = false;
+  while (true) {
+    bool this_time_changed = false;
+    auto node_list = TopoSort(graph->get_return());
+    for (auto &node : node_list) {
+      if (node->func_graph() != graph) {
+        continue;
+      }
+      if (!utils::isa<CNodePtr>(node) || GetCNodeType(node) != schema::PrimitiveType_Slice) {
+        continue;
+      }
+      auto slice_cnode = node->cast<CNodePtr>();
+      if (slice_cnode->inputs().size() != lite::kDoubleNum) {  // only support params from attrs now
| MS_LOG(INFO) << "SlicePrepose not support more than two inputs now"; | |||||
+        continue;
+      }
+      auto primt = GetSliceT(slice_cnode);
+      if (primt == nullptr) {
+        MS_LOG(ERROR) << "primitive_t of slice is nullptr";
+        continue;
+      }
+      auto preceed_node = slice_cnode->input(1);
+      if (preceed_node == nullptr) {
+        MS_LOG(ERROR) << "preceed node is nullptr";
+        continue;
+      }
+      auto output_tensor_num = GetOutputTensorNum(preceed_node);
+      if (output_tensor_num > 1) {
+        continue;
+      }
+      auto output_node_list = GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
+      if (output_node_list->size() > 1) {  // referenced by multi nodes
+        continue;
+      } else {
+        if (utils::isa<ParameterPtr>(preceed_node)) {
+          /*
+           * if preceed_node is parameter without default param, it's input placeholder, so we can't prepose
+           * if preceed_node is parameter with default param, constant_folding will process it
+           */
+          continue;
+        }
+        auto preceed_cnode = preceed_node->cast<CNodePtr>();
+        if (preceed_cnode == nullptr) {
+          MS_LOG(ERROR) << "preceed_cnode is nullptr";
+          continue;
+        }
+        if (DoPrepose(graph, slice_cnode, preceed_cnode)) {
+          this_time_changed = true;
+          break;
+        }
+      }
+    }
+    if (this_time_changed) {
+      changed = true;
+    } else {
+      break;
+    }
+  }
+  return changed;
+}
+}  // namespace mindspore::opt
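The prepose condition documented above PreposeWithSoftmax comes down to one rule: the Slice may move in front of the Softmax only if it leaves the softmax axis untouched. A self-contained sketch of that rule on plain vectors (illustrative only; the name CanPreposeSliceOverSoftmax and the free-function form are assumptions, not part of the patch):

#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the per-axis checks in PreposeWithSoftmax: begin must be 0 on the softmax axis
// and size must be -1 or cover the whole dimension; slicing any other axis is always safe,
// since softmax normalizes each slice along its own axis independently.
bool CanPreposeSliceOverSoftmax(int softmax_axis, const std::vector<int32_t> &input_shape,
                                const std::vector<int32_t> &axes, const std::vector<int32_t> &begin,
                                const std::vector<int32_t> &size) {
  if (softmax_axis == -1) {
    if (input_shape.empty()) {
      return false;  // shape info is required to resolve the negative axis
    }
    softmax_axis += static_cast<int>(input_shape.size());
  }
  for (size_t i = 0; i < axes.size(); ++i) {
    if (axes[i] != softmax_axis) {
      continue;  // slicing other axes does not change the softmax result along its axis
    }
    if (begin[i] != 0) {
      return false;  // the slice starts inside the softmax axis
    }
    if (size[i] == -1) {
      continue;  // -1 keeps the rest of the axis, i.e. the whole axis when begin is 0
    }
    if (input_shape.empty() || axes[i] >= static_cast<int>(input_shape.size()) ||
        size[i] < input_shape[axes[i]]) {
      return false;  // unknown shape, or the slice shortens the softmax axis
    }
  }
  return true;
}

For example, with input_shape {4, 10} and softmax axis 1, a slice over axes {0} with begin {0} and size {2} passes the check, while a slice over axes {1} with size {5} does not.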
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_
+#include <vector>
+#include <memory>
+#include <utility>
+#include "tools/converter/converter_flags.h"
+#include "backend/optimizer/common/pass.h"
+#include "include/errorcode.h"
+#include "mindspore/core/ir/manager.h"
+#include "schema/inner/model_generated.h"
+using mindspore::lite::converter::FmkType;
+namespace mindspore::opt {
+using lite::RET_ERROR;
+using lite::RET_OK;
+using lite::STATUS;
+using TransactionPtr = std::shared_ptr<mindspore::FuncGraphTransaction>;
+using NodeUsedListPtr = std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>>;
+class SlicePreposePass : public Pass {
+ public:
+  SlicePreposePass() : Pass("slice_prepose_pass") {}
+  ~SlicePreposePass() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+  void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; }
+ private:
+  schema::SliceT *GetSliceT(const CNodePtr &cnode);
+  bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode);
+  STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode,
+                              const int index, const TransactionPtr &tr = nullptr);
+  bool PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &softmax_cnode);
+ private:
+  FmkType fmk_type = lite::converter::FmkType_ONNX;
+};
+}  // namespace mindspore::opt
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_