| @@ -824,6 +824,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new InstanceNorm(primitive); | |||
| case schema::PrimitiveType_While: | |||
| return new While(primitive); | |||
| case schema::PrimitiveType_OnnxInt8Quantize: | |||
| return new Quant(primitive); | |||
| case schema::PrimitiveType_OnnxInt8Dequantize: | |||
| return new Dequant(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| @@ -196,6 +196,7 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | |||
| ) | |||
| endif() | |||
| @@ -152,7 +152,7 @@ std::string FlagParser::Usage(const Option<std::string> &usgMsg) const { | |||
| // first line, brief of the usage | |||
| std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : ""; | |||
| // usage of bin name | |||
| usageString += usageMsg.IsNone() ? "usage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; | |||
| usageString += usageMsg.IsNone() ? "\nusage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; | |||
| // help line of help message, usageLine:message of parametors | |||
| std::string helpLine = ""; | |||
| std::string usageLine = ""; | |||
| @@ -47,6 +47,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/weight_format_hardcode_pass.cc | |||
| ../optimizer/graph/clip_convert_activation_pass.cc | |||
| ../optimizer/graph/unused_cast_node_remove_pass.cc | |||
| ../optimizer/graph/unused_transpose_node_remove_pass.cc | |||
| ../optimizer/graph/identity_remove_pass.cc | |||
| ) | |||
| @@ -33,6 +33,7 @@ | |||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | |||
| #include "tools/optimizer/graph/clip_convert_activation_pass.h" | |||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | |||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| @@ -90,9 +91,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| if (config->fmk == lite::converter::FmkType_MS) { | |||
| auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); | |||
| if (remove_unused_cast_pass == nullptr) { | |||
| MS_LOG(ERROR) << "RemoveUnusedCastOpPass shoud be specified"; | |||
| return nullptr; | |||
| } | |||
| remove_unused_cast_pass->SetFmkType(config->fmk); | |||
| pm->AddPass(remove_unused_cast_pass); | |||
| } | |||
| if (config->fmk == lite::converter::FmkType_ONNX) { | |||
| auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>(); | |||
| if (remove_unused_transpose_pass == nullptr) { | |||
| MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass shoud be specified"; | |||
| return nullptr; | |||
| } | |||
| remove_unused_transpose_pass->SetFmkType(config->fmk); | |||
| pm->AddPass(remove_unused_transpose_pass); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | |||
| optimizer->AddPassManager(convert_pm); | |||
| @@ -61,5 +61,6 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx | |||
| } | |||
| OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser()); | |||
| OnnxNodeRegistrar g_onnxInt8TransposeParser("Int8Transpose", new OnnxTransposeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,90 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore::opt { | |||
| static constexpr size_t kTransposeInput = 1; | |||
| const std::vector<int> kPermNCHW{0, 3, 1, 2}; | |||
| const std::vector<int> kPermNHWC{0, 2, 3, 1}; | |||
| void RemoveUnusedTransposeOpPass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||
| bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| if (this->fmk_type != lite::converter::FmkType_ONNX) { | |||
| MS_LOG(ERROR) << "The framework type of model should be onnx."; | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type == schema::PrimitiveType_Transpose) { | |||
| auto transpose_cnode = node->cast<CNodePtr>(); | |||
| auto typeInput = opt::GetCNodeType(transpose_cnode->input(kTransposeInput)); | |||
| if (typeInput != schema::PrimitiveType_Conv2D) { | |||
| continue; | |||
| } | |||
| auto primPtr = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transpose_cnode->input(0)); | |||
| if (primPtr == nullptr) { | |||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; | |||
| return RET_ERROR; | |||
| } | |||
| auto primT = primPtr->GetPrimitiveT(); | |||
| if (primT == nullptr) { | |||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int32_t> perm = primT->value.AsTranspose()->perm; | |||
| if (perm == kPermNCHW) { | |||
| manager->Replace(transpose_cnode, transpose_cnode->input(1)); | |||
| } | |||
| } else if (type == schema::PrimitiveType_Conv2D) { | |||
| auto conv_node = node->cast<CNodePtr>(); | |||
| auto typeInput = opt::GetCNodeType(conv_node->input(kTransposeInput)); | |||
| if (typeInput != schema::PrimitiveType_Transpose) { | |||
| continue; | |||
| } | |||
| auto transpose_cnode = conv_node->input(kTransposeInput)->cast<CNodePtr>(); | |||
| auto primPtr = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transpose_cnode->input(0)); | |||
| if (primPtr == nullptr) { | |||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; | |||
| return RET_ERROR; | |||
| } | |||
| auto primT = primPtr->GetPrimitiveT(); | |||
| if (primT == nullptr) { | |||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveT"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int32_t> perm = primT->value.AsTranspose()->perm; | |||
| if (perm == kPermNHWC) { | |||
| manager->Replace(transpose_cnode, transpose_cnode->input(1)); | |||
| } | |||
| } else { | |||
| continue; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ | |||
| #include <string> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| class RemoveUnusedTransposeOpPass : public Pass { | |||
| public: | |||
| RemoveUnusedTransposeOpPass() : Pass("remove_unused_cast_pass") {} | |||
| ~RemoveUnusedTransposeOpPass() override = default; | |||
| void SetFmkType(FmkType fmkType); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| FmkType fmk_type = lite::converter::FmkType_TF; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ | |||