| @@ -824,6 +824,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new InstanceNorm(primitive); | return new InstanceNorm(primitive); | ||||
| case schema::PrimitiveType_While: | case schema::PrimitiveType_While: | ||||
| return new While(primitive); | return new While(primitive); | ||||
| case schema::PrimitiveType_OnnxInt8Quantize: | |||||
| return new Quant(primitive); | |||||
| case schema::PrimitiveType_OnnxInt8Dequantize: | |||||
| return new Dequant(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| @@ -196,6 +196,7 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc | |||||
| ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | ||||
| ) | ) | ||||
| endif() | endif() | ||||
| @@ -152,7 +152,7 @@ std::string FlagParser::Usage(const Option<std::string> &usgMsg) const { | |||||
| // first line, brief of the usage | // first line, brief of the usage | ||||
| std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : ""; | std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : ""; | ||||
| // usage of bin name | // usage of bin name | ||||
| usageString += usageMsg.IsNone() ? "usage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; | |||||
| usageString += usageMsg.IsNone() ? "\nusage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; | |||||
| // help line of help message, usageLine:message of parametors | // help line of help message, usageLine:message of parametors | ||||
| std::string helpLine = ""; | std::string helpLine = ""; | ||||
| std::string usageLine = ""; | std::string usageLine = ""; | ||||
| @@ -47,6 +47,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/graph/weight_format_hardcode_pass.cc | ../optimizer/graph/weight_format_hardcode_pass.cc | ||||
| ../optimizer/graph/clip_convert_activation_pass.cc | ../optimizer/graph/clip_convert_activation_pass.cc | ||||
| ../optimizer/graph/unused_cast_node_remove_pass.cc | ../optimizer/graph/unused_cast_node_remove_pass.cc | ||||
| ../optimizer/graph/unused_transpose_node_remove_pass.cc | |||||
| ../optimizer/graph/identity_remove_pass.cc | ../optimizer/graph/identity_remove_pass.cc | ||||
| ) | ) | ||||
| @@ -33,6 +33,7 @@ | |||||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | #include "tools/optimizer/graph/weight_format_transform_pass.h" | ||||
| #include "tools/optimizer/graph/clip_convert_activation_pass.h" | #include "tools/optimizer/graph/clip_convert_activation_pass.h" | ||||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | ||||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | |||||
| #include "tools/converter/quantizer/post_training_quantizer.h" | #include "tools/converter/quantizer/post_training_quantizer.h" | ||||
| #include "tools/converter/quantizer/quant_cast.h" | #include "tools/converter/quantizer/quant_cast.h" | ||||
| #include "tools/converter/quantizer/weight_quantizer.h" | #include "tools/converter/quantizer/weight_quantizer.h" | ||||
| @@ -90,9 +91,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| if (config->fmk == lite::converter::FmkType_MS) { | if (config->fmk == lite::converter::FmkType_MS) { | ||||
| auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); | auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); | ||||
| if (remove_unused_cast_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "RemoveUnusedCastOpPass shoud be specified"; | |||||
| return nullptr; | |||||
| } | |||||
| remove_unused_cast_pass->SetFmkType(config->fmk); | remove_unused_cast_pass->SetFmkType(config->fmk); | ||||
| pm->AddPass(remove_unused_cast_pass); | pm->AddPass(remove_unused_cast_pass); | ||||
| } | } | ||||
| if (config->fmk == lite::converter::FmkType_ONNX) { | |||||
| auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>(); | |||||
| if (remove_unused_transpose_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass shoud be specified"; | |||||
| return nullptr; | |||||
| } | |||||
| remove_unused_transpose_pass->SetFmkType(config->fmk); | |||||
| pm->AddPass(remove_unused_transpose_pass); | |||||
| } | |||||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | ||||
| convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | ||||
| optimizer->AddPassManager(convert_pm); | optimizer->AddPassManager(convert_pm); | ||||
| @@ -61,5 +61,6 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser()); | OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser()); | ||||
| OnnxNodeRegistrar g_onnxInt8TransposeParser("Int8Transpose", new OnnxTransposeParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,90 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| #include "mindspore/lite/include/errorcode.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore::opt { | |||||
| static constexpr size_t kTransposeInput = 1; | |||||
| const std::vector<int> kPermNCHW{0, 3, 1, 2}; | |||||
| const std::vector<int> kPermNHWC{0, 2, 3, 1}; | |||||
| void RemoveUnusedTransposeOpPass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||||
| bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) { | |||||
| if (this->fmk_type != lite::converter::FmkType_ONNX) { | |||||
| MS_LOG(ERROR) << "The framework type of model should be onnx."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_ASSERT(manager != nullptr); | |||||
| auto node_list = TopoSort(func_graph->get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (!utils::isa<CNodePtr>(node)) { | |||||
| continue; | |||||
| } | |||||
| auto type = opt::GetCNodeType(node); | |||||
| if (type == schema::PrimitiveType_Transpose) { | |||||
| auto transpose_cnode = node->cast<CNodePtr>(); | |||||
| auto typeInput = opt::GetCNodeType(transpose_cnode->input(kTransposeInput)); | |||||
| if (typeInput != schema::PrimitiveType_Conv2D) { | |||||
| continue; | |||||
| } | |||||
| auto primPtr = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transpose_cnode->input(0)); | |||||
| if (primPtr == nullptr) { | |||||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto primT = primPtr->GetPrimitiveT(); | |||||
| if (primT == nullptr) { | |||||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int32_t> perm = primT->value.AsTranspose()->perm; | |||||
| if (perm == kPermNCHW) { | |||||
| manager->Replace(transpose_cnode, transpose_cnode->input(1)); | |||||
| } | |||||
| } else if (type == schema::PrimitiveType_Conv2D) { | |||||
| auto conv_node = node->cast<CNodePtr>(); | |||||
| auto typeInput = opt::GetCNodeType(conv_node->input(kTransposeInput)); | |||||
| if (typeInput != schema::PrimitiveType_Transpose) { | |||||
| continue; | |||||
| } | |||||
| auto transpose_cnode = conv_node->input(kTransposeInput)->cast<CNodePtr>(); | |||||
| auto primPtr = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transpose_cnode->input(0)); | |||||
| if (primPtr == nullptr) { | |||||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto primT = primPtr->GetPrimitiveT(); | |||||
| if (primT == nullptr) { | |||||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveT"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int32_t> perm = primT->value.AsTranspose()->perm; | |||||
| if (perm == kPermNHWC) { | |||||
| manager->Replace(transpose_cnode, transpose_cnode->input(1)); | |||||
| } | |||||
| } else { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace mindspore::opt | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ | |||||
| #include <string> | |||||
| #include "backend/optimizer/common/pass.h" | |||||
| #include "tools/converter/converter_flags.h" | |||||
| using mindspore::lite::converter::FmkType; | |||||
| namespace mindspore::opt { | |||||
| class RemoveUnusedTransposeOpPass : public Pass { | |||||
| public: | |||||
| RemoveUnusedTransposeOpPass() : Pass("remove_unused_cast_pass") {} | |||||
| ~RemoveUnusedTransposeOpPass() override = default; | |||||
| void SetFmkType(FmkType fmkType); | |||||
| bool Run(const FuncGraphPtr &graph) override; | |||||
| private: | |||||
| FmkType fmk_type = lite::converter::FmkType_TF; | |||||
| }; | |||||
| } // namespace mindspore::opt | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ | |||||