diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc index eead37132f..b963f8747c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -36,15 +36,15 @@ int ArithmeticCPUKernel::Init() { if (!InferShapeDone()) { return RET_OK; } + return ReSize(); +} + +int ArithmeticCPUKernel::ReSize() { if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) { data_type_ = kDataTypeFloat; } else { data_type_ = kDataTypeInt; } - return ReSize(); -} - -int ArithmeticCPUKernel::ReSize() { arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index aac448cf5c..103a11e025 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -183,6 +183,7 @@ if(BUILD_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc + ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc ) endif() ### train diff --git a/mindspore/lite/test/models_mindspore.cfg b/mindspore/lite/test/models_mindspore.cfg index 8f3c8a9f94..68441dc2ce 100644 --- a/mindspore/lite/test/models_mindspore.cfg +++ b/mindspore/lite/test/models_mindspore.cfg @@ -4,3 +4,4 @@ gate_u_net_small-1_110.mindir shufflenetv2.mindir inceptionv3.mindir googlenet.mindir +resnext50.mindir diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 66857a3a02..bf0567170c 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -60,6 +60,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/quant_dtype_cast_fusion.cc ../optimizer/graph/weight_format_transform_pass.cc ../optimizer/graph/weight_format_hardcode_pass.cc + ../optimizer/graph/unused_cast_node_remove_pass.cc ) add_subdirectory(../anf_importer anf_importer) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 10cf6b3352..d5f90a00e5 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -27,6 +27,7 @@ #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "tools/optimizer/graph/weight_format_hardcode_pass.h" #include "tools/optimizer/graph/weight_format_transform_pass.h" +#include "tools/optimizer/graph/unused_cast_node_remove_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/weight_quantizer.h" @@ -72,6 +73,11 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver graph_pm->AddPass(weight_format_transform_pass); } + if (config->fmk == lite::converter::FmkType_MS) { + auto remove_unused_cast_pass = std::make_shared(); + remove_unused_cast_pass->SetFmkType(config->fmk); + pm->AddPass(remove_unused_cast_pass); + } pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); optimizer->AddPassManager(graph_pm); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc index 455779cf7d..d252fe5713 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc @@ -29,25 +29,12 @@ namespace mindspore { namespace lite { -bool IsUnusedNode(const CNodeT &node) { - if (node.primitive->value.type == schema::PrimitiveType_TupleGetItem) { - return true; - } - if (node.primitive->value.type == schema::PrimitiveType_Cast) { - auto attr = reinterpret_cast(node.primitive->value.value); - if (attr->srcT == kNumberTypeFloat32 && attr->dstT == kNumberTypeFloat16) { - return true; - } - } - return false; -} - STATUS UnusedNodeRemovePass::Run(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); bool ifChanged = false; for (size_t i = 0; i < graph->nodes.size(); i++) { auto &node = graph->nodes.at(i); - if (IsUnusedNode(*node)) { + if (node->primitive->value.type == schema::PrimitiveType_TupleGetItem) { ifChanged = true; auto status = IsolateOneWayNode(graph, i); if (status != RET_OK) { diff --git a/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc new file mode 100644 index 0000000000..fc893ebda3 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/unused_cast_node_remove_pass.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "mindspore/lite/include/errorcode.h" +#include "src/ops/primitive_c.h" + +namespace mindspore::opt { +void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; } + +bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) { + if (this->fmk_type != lite::converter::FmkType_MS) { + MS_LOG(ERROR) << "The framework type of model should be mindspore."; + return RET_ERROR; + } + MS_ASSERT(func_graph != nullptr); + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto type = opt::GetCNodeType(node); + if (type != schema::PrimitiveType_Cast) { + continue; + } + auto cast_cnode = node->cast(); + auto abstract_base = cast_cnode->input(1)->abstract(); + if (abstract_base == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << cast_cnode->input(1)->fullname_with_scope(); + return RET_ERROR; + } + if (!utils::isa(abstract_base)) { + MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " + << cast_cnode->input(1)->fullname_with_scope(); + return RET_ERROR; + } + auto abstract_tensor = utils::cast(abstract_base); + auto input_type = abstract_tensor->element()->GetTypeTrack(); + MS_ASSERT(input_type != nullptr); + auto input_type_value = input_type->type_id(); + + if (cast_cnode->inputs().size() != lite::kMultiNum || !utils::isa(cast_cnode->input(2))) { + MS_LOG(ERROR) << "Second input of cast should be a ValueNode"; + return RET_ERROR; + } + auto output_type = GetValueNode(cast_cnode->input(2)); + if (output_type == nullptr) { + MS_LOG(ERROR) << "Second input of cast is nullptr"; + return RET_ERROR; + } + auto output_type_value = output_type->type_id(); + if ((input_type_value == kNumberTypeFloat32 && output_type_value == kNumberTypeFloat16) || + (input_type_value == kNumberTypeFloat16 && output_type_value == kNumberTypeFloat32)) { + manager->Replace(node, cast_cnode->input(1)); + } + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.h b/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.h new file mode 100644 index 0000000000..4264e0d3d0 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_ +#include +#include "backend/optimizer/common/pass.h" +#include "tools/converter/converter_flags.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { +class RemoveUnusedCastOpPass : public Pass { + public: + RemoveUnusedCastOpPass() : Pass("remove_unused_cast_pass") {} + ~RemoveUnusedCastOpPass() override = default; + void SetFmkType(FmkType fmkType); + bool Run(const FuncGraphPtr &graph) override; + + private: + FmkType fmk_type = lite::converter::FmkType_TF; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_