| @@ -36,15 +36,15 @@ int ArithmeticCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int ArithmeticCPUKernel::ReSize() { | |||
| if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) { | |||
| data_type_ = kDataTypeFloat; | |||
| } else { | |||
| data_type_ = kDataTypeInt; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int ArithmeticCPUKernel::ReSize() { | |||
| arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); | |||
| arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); | |||
| arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); | |||
| @@ -183,6 +183,7 @@ if(BUILD_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | |||
| ) | |||
| endif() | |||
| ### train | |||
| @@ -4,3 +4,4 @@ gate_u_net_small-1_110.mindir | |||
| shufflenetv2.mindir | |||
| inceptionv3.mindir | |||
| googlenet.mindir | |||
| resnext50.mindir | |||
| @@ -60,6 +60,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/fusion/quant_dtype_cast_fusion.cc | |||
| ../optimizer/graph/weight_format_transform_pass.cc | |||
| ../optimizer/graph/weight_format_hardcode_pass.cc | |||
| ../optimizer/graph/unused_cast_node_remove_pass.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| @@ -27,6 +27,7 @@ | |||
| #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" | |||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | |||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| @@ -72,6 +73,11 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| graph_pm->AddPass(weight_format_transform_pass); | |||
| } | |||
| if (config->fmk == lite::converter::FmkType_MS) { | |||
| auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); | |||
| remove_unused_cast_pass->SetFmkType(config->fmk); | |||
| pm->AddPass(remove_unused_cast_pass); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| optimizer->AddPassManager(pm); | |||
| optimizer->AddPassManager(graph_pm); | |||
| @@ -29,25 +29,12 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| bool IsUnusedNode(const CNodeT &node) { | |||
| if (node.primitive->value.type == schema::PrimitiveType_TupleGetItem) { | |||
| return true; | |||
| } | |||
| if (node.primitive->value.type == schema::PrimitiveType_Cast) { | |||
| auto attr = reinterpret_cast<schema::CastT *>(node.primitive->value.value); | |||
| if (attr->srcT == kNumberTypeFloat32 && attr->dstT == kNumberTypeFloat16) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| STATUS UnusedNodeRemovePass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| bool ifChanged = false; | |||
| for (size_t i = 0; i < graph->nodes.size(); i++) { | |||
| auto &node = graph->nodes.at(i); | |||
| if (IsUnusedNode(*node)) { | |||
| if (node->primitive->value.type == schema::PrimitiveType_TupleGetItem) { | |||
| ifChanged = true; | |||
| auto status = IsolateOneWayNode(graph, i); | |||
| if (status != RET_OK) { | |||
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore::opt { | |||
| void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||
| bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| if (this->fmk_type != lite::converter::FmkType_MS) { | |||
| MS_LOG(ERROR) << "The framework type of model should be mindspore."; | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Cast) { | |||
| continue; | |||
| } | |||
| auto cast_cnode = node->cast<CNodePtr>(); | |||
| auto abstract_base = cast_cnode->input(1)->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << cast_cnode->input(1)->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " | |||
| << cast_cnode->input(1)->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| auto input_type = abstract_tensor->element()->GetTypeTrack(); | |||
| MS_ASSERT(input_type != nullptr); | |||
| auto input_type_value = input_type->type_id(); | |||
| if (cast_cnode->inputs().size() != lite::kMultiNum || !utils::isa<ValueNodePtr>(cast_cnode->input(2))) { | |||
| MS_LOG(ERROR) << "Second input of cast should be a ValueNode"; | |||
| return RET_ERROR; | |||
| } | |||
| auto output_type = GetValueNode<NumberPtr>(cast_cnode->input(2)); | |||
| if (output_type == nullptr) { | |||
| MS_LOG(ERROR) << "Second input of cast is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto output_type_value = output_type->type_id(); | |||
| if ((input_type_value == kNumberTypeFloat32 && output_type_value == kNumberTypeFloat16) || | |||
| (input_type_value == kNumberTypeFloat16 && output_type_value == kNumberTypeFloat32)) { | |||
| manager->Replace(node, cast_cnode->input(1)); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_ | |||
| #include <string> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| class RemoveUnusedCastOpPass : public Pass { | |||
| public: | |||
| RemoveUnusedCastOpPass() : Pass("remove_unused_cast_pass") {} | |||
| ~RemoveUnusedCastOpPass() override = default; | |||
| void SetFmkType(FmkType fmkType); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| FmkType fmk_type = lite::converter::FmkType_TF; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_ | |||