From c42e53ae5b1b4d5355469859de84d6dc57df3b51 Mon Sep 17 00:00:00 2001 From: geekun Date: Mon, 30 Mar 2020 19:20:26 +0800 Subject: [PATCH] fix custom op bug and add custom op check --- mindspore/ccsrc/transform/op_adapter.h | 25 ++++++++++++++++++++ mindspore/ccsrc/transform/op_adapter_util.cc | 8 ++++++- 2 files changed, 32 insertions(+), 1 deletion(-) mode change 100755 => 100644 mindspore/ccsrc/transform/op_adapter_util.cc diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h index aa466adbb8..929117101b 100644 --- a/mindspore/ccsrc/transform/op_adapter.h +++ b/mindspore/ccsrc/transform/op_adapter.h @@ -279,6 +279,31 @@ class OpAdapter : public BaseOpAdapter { } OutHandler getOutput(const OperatorPtr& op, int index) override { + MS_EXCEPTION_IF_NULL(op); + if (IsCustomOp(op)) { + return getCustomOutput(op, index); + } + return getNormalOutput(op, index); + } + + OutHandler getCustomOutput(const OperatorPtr& op, int index) { + MS_EXCEPTION_IF_NULL(op); + auto it = cus_output_map_.find(op->GetOpType()); + if (it == cus_output_map_.end()) { + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT is not supported!"; + return OutHandler(); + } + + std::unordered_map& output_map = it->second; + + if ((output_map.find(index) != output_map.end())) { + return OutHandler(op, output_map[index]); + } + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT index(" << index << ")!"; + return OutHandler(); + } + + OutHandler getNormalOutput(const OperatorPtr& op, int index) { MS_EXCEPTION_IF_NULL(op); if (!dyn_output_map_.empty() && !output_map_.empty()) { MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!"; diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc old mode 100755 new mode 100644 index 49b8714837..d52699fa8f --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -223,7 +223,13 @@ bool IsCustomPrim(const PrimitivePtr& prim) { return false; } - return GetValue(flag); + bool is_custom_op = GetValue(flag); + if (!is_custom_op && prim->GetAttr("_custom_op_impl_config_path") != nullptr) { + MS_LOG(EXCEPTION) << "The custom op flag is false, but the op information config path is not null, non-custom op " + "can not assign the op information config path."; + } + + return is_custom_op; } bool IsCustomCNode(const AnfNodePtr& anf) {