From 65cbb7b08fbd986fd6805a690921d15fd642d4a1 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Wed, 24 Feb 2021 11:53:06 +0800 Subject: [PATCH] Change io_format in adapter --- .../ccsrc/transform/graph_ir/io_format_map.cc | 24 ++++++++++++++ .../ccsrc/transform/graph_ir/io_format_map.h | 33 +++++++++++++++++++ .../transform/graph_ir/op_adapter_util.cc | 14 +++++--- mindspore/ops/operations/nn_ops.py | 1 - 4 files changed, 67 insertions(+), 5 deletions(-) create mode 100644 mindspore/ccsrc/transform/graph_ir/io_format_map.cc create mode 100644 mindspore/ccsrc/transform/graph_ir/io_format_map.h diff --git a/mindspore/ccsrc/transform/graph_ir/io_format_map.cc b/mindspore/ccsrc/transform/graph_ir/io_format_map.cc new file mode 100644 index 0000000000..8d265cc7e4 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/io_format_map.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/io_format_map.h" + +namespace mindspore { +namespace transform { +std::unordered_map IOFormatMap::io_format_map_ = {{"MatMul", "ND"}, {"Conv3D", "format"}}; +std::unordered_map &IOFormatMap::get() { return io_format_map_; } +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/io_format_map.h b/mindspore/ccsrc/transform/graph_ir/io_format_map.h new file mode 100644 index 0000000000..c53a79103d --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/io_format_map.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_IO_FORMAT_MAP_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_IO_FORMAT_MAP_H_ + +#include +#include + +namespace mindspore { +namespace transform { +class IOFormatMap { + public: + static std::unordered_map &get(); + + private: + static std::unordered_map io_format_map_; +}; +} // namespace transform +} // namespace mindspore +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_IO_FORMAT_MAP_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc index 55e297c546..36df9eb6df 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc @@ -22,6 +22,7 @@ #include "utils/utils.h" #include "transform/graph_ir/op_adapter_base.h" +#include "transform/graph_ir/io_format_map.h" namespace mindspore { namespace transform { @@ -293,12 +294,17 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) { MS_LOG(ERROR) << "The anf is not a Primitive."; return ret; } - ValuePtr format = prim->GetAttr("io_format"); - if (format == nullptr) { + auto io_format_map = IOFormatMap::get(); + auto iter = io_format_map.find(prim->name()); + if (iter == io_format_map.end()) { return "NCHW"; } - ret = GetValue(format); - return ret; + if (iter->second == "format") { + ValuePtr format = prim->GetAttr("format"); + MS_EXCEPTION_IF_NULL(format); + return GetValue(format); + } + return iter->second; } } // namespace transform } // namespace mindspore diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 2f86382303..034ee1d265 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -7490,7 +7490,6 @@ class Conv3D(PrimitiveWithInfer): self.add_prim_attr('mode', self.mode) self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) self.add_prim_attr('data_format', self.format) - self.add_prim_attr('io_format', self.format) self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) self.group = validator.check_positive_int(group, 'group', self.name) self.add_prim_attr('groups', self.group)