Merge pull request !6917 from wangshaocong/bugfix_mastertags/v1.1.0
| @@ -36,15 +36,15 @@ int ArithmeticCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | if (!InferShapeDone()) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| return ReSize(); | |||||
| } | |||||
| int ArithmeticCPUKernel::ReSize() { | |||||
| if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) { | if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) { | ||||
| data_type_ = kDataTypeFloat; | data_type_ = kDataTypeFloat; | ||||
| } else { | } else { | ||||
| data_type_ = kDataTypeInt; | data_type_ = kDataTypeInt; | ||||
| } | } | ||||
| return ReSize(); | |||||
| } | |||||
| int ArithmeticCPUKernel::ReSize() { | |||||
| arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); | arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); | ||||
| arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); | arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); | ||||
| arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); | arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); | ||||
| @@ -183,6 +183,7 @@ if(BUILD_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| ### train | ### train | ||||
| @@ -4,3 +4,4 @@ gate_u_net_small-1_110.mindir | |||||
| shufflenetv2.mindir | shufflenetv2.mindir | ||||
| inceptionv3.mindir | inceptionv3.mindir | ||||
| googlenet.mindir | googlenet.mindir | ||||
| resnext50.mindir | |||||
| @@ -60,6 +60,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/fusion/quant_dtype_cast_fusion.cc | ../optimizer/fusion/quant_dtype_cast_fusion.cc | ||||
| ../optimizer/graph/weight_format_transform_pass.cc | ../optimizer/graph/weight_format_transform_pass.cc | ||||
| ../optimizer/graph/weight_format_hardcode_pass.cc | ../optimizer/graph/weight_format_hardcode_pass.cc | ||||
| ../optimizer/graph/unused_cast_node_remove_pass.cc | |||||
| ) | ) | ||||
| add_subdirectory(../anf_importer anf_importer) | add_subdirectory(../anf_importer anf_importer) | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" | #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" | ||||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | ||||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | #include "tools/optimizer/graph/weight_format_transform_pass.h" | ||||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | |||||
| #include "tools/converter/quantizer/post_training_quantizer.h" | #include "tools/converter/quantizer/post_training_quantizer.h" | ||||
| #include "tools/converter/quantizer/quant_cast.h" | #include "tools/converter/quantizer/quant_cast.h" | ||||
| #include "tools/converter/quantizer/weight_quantizer.h" | #include "tools/converter/quantizer/weight_quantizer.h" | ||||
| @@ -72,6 +73,11 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| graph_pm->AddPass(weight_format_transform_pass); | graph_pm->AddPass(weight_format_transform_pass); | ||||
| } | } | ||||
| if (config->fmk == lite::converter::FmkType_MS) { | |||||
| auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); | |||||
| remove_unused_cast_pass->SetFmkType(config->fmk); | |||||
| pm->AddPass(remove_unused_cast_pass); | |||||
| } | |||||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | ||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| optimizer->AddPassManager(graph_pm); | optimizer->AddPassManager(graph_pm); | ||||
| @@ -29,25 +29,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| bool IsUnusedNode(const CNodeT &node) { | |||||
| if (node.primitive->value.type == schema::PrimitiveType_TupleGetItem) { | |||||
| return true; | |||||
| } | |||||
| if (node.primitive->value.type == schema::PrimitiveType_Cast) { | |||||
| auto attr = reinterpret_cast<schema::CastT *>(node.primitive->value.value); | |||||
| if (attr->srcT == kNumberTypeFloat32 && attr->dstT == kNumberTypeFloat16) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| STATUS UnusedNodeRemovePass::Run(schema::MetaGraphT *graph) { | STATUS UnusedNodeRemovePass::Run(schema::MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| bool ifChanged = false; | bool ifChanged = false; | ||||
| for (size_t i = 0; i < graph->nodes.size(); i++) { | for (size_t i = 0; i < graph->nodes.size(); i++) { | ||||
| auto &node = graph->nodes.at(i); | auto &node = graph->nodes.at(i); | ||||
| if (IsUnusedNode(*node)) { | |||||
| if (node->primitive->value.type == schema::PrimitiveType_TupleGetItem) { | |||||
| ifChanged = true; | ifChanged = true; | ||||
| auto status = IsolateOneWayNode(graph, i); | auto status = IsolateOneWayNode(graph, i); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -0,0 +1,74 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| #include "mindspore/lite/include/errorcode.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore::opt { | |||||
| void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||||
| bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) { | |||||
| if (this->fmk_type != lite::converter::FmkType_MS) { | |||||
| MS_LOG(ERROR) << "The framework type of model should be mindspore."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_ASSERT(manager != nullptr); | |||||
| auto node_list = TopoSort(func_graph->get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (!utils::isa<CNodePtr>(node)) { | |||||
| continue; | |||||
| } | |||||
| auto type = opt::GetCNodeType(node); | |||||
| if (type != schema::PrimitiveType_Cast) { | |||||
| continue; | |||||
| } | |||||
| auto cast_cnode = node->cast<CNodePtr>(); | |||||
| auto abstract_base = cast_cnode->input(1)->abstract(); | |||||
| if (abstract_base == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << cast_cnode->input(1)->fullname_with_scope(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " | |||||
| << cast_cnode->input(1)->fullname_with_scope(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||||
| auto input_type = abstract_tensor->element()->GetTypeTrack(); | |||||
| MS_ASSERT(input_type != nullptr); | |||||
| auto input_type_value = input_type->type_id(); | |||||
| if (cast_cnode->inputs().size() != lite::kMultiNum || !utils::isa<ValueNodePtr>(cast_cnode->input(2))) { | |||||
| MS_LOG(ERROR) << "Second input of cast should be a ValueNode"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto output_type = GetValueNode<NumberPtr>(cast_cnode->input(2)); | |||||
| if (output_type == nullptr) { | |||||
| MS_LOG(ERROR) << "Second input of cast is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto output_type_value = output_type->type_id(); | |||||
| if ((input_type_value == kNumberTypeFloat32 && output_type_value == kNumberTypeFloat16) || | |||||
| (input_type_value == kNumberTypeFloat16 && output_type_value == kNumberTypeFloat32)) { | |||||
| manager->Replace(node, cast_cnode->input(1)); | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace mindspore::opt | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_ | |||||
| #include <string> | |||||
| #include "backend/optimizer/common/pass.h" | |||||
| #include "tools/converter/converter_flags.h" | |||||
| using mindspore::lite::converter::FmkType; | |||||
| namespace mindspore::opt { | |||||
| class RemoveUnusedCastOpPass : public Pass { | |||||
| public: | |||||
| RemoveUnusedCastOpPass() : Pass("remove_unused_cast_pass") {} | |||||
| ~RemoveUnusedCastOpPass() override = default; | |||||
| void SetFmkType(FmkType fmkType); | |||||
| bool Run(const FuncGraphPtr &graph) override; | |||||
| private: | |||||
| FmkType fmk_type = lite::converter::FmkType_TF; | |||||
| }; | |||||
| } // namespace mindspore::opt | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_ | |||||