Browse Source

Change io_format in adapter

tags/v1.2.0-rc1
l00591931 4 years ago
parent
commit
65cbb7b08f
4 changed files with 67 additions and 5 deletions
  1. +24
    -0
      mindspore/ccsrc/transform/graph_ir/io_format_map.cc
  2. +33
    -0
      mindspore/ccsrc/transform/graph_ir/io_format_map.h
  3. +10
    -4
      mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
  4. +0
    -1
      mindspore/ops/operations/nn_ops.py

+ 24
- 0
mindspore/ccsrc/transform/graph_ir/io_format_map.cc View File

@@ -0,0 +1,24 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "transform/graph_ir/io_format_map.h"

namespace mindspore {
namespace transform {
std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"MatMul", "ND"}, {"Conv3D", "format"}};
std::unordered_map<std::string, std::string> &IOFormatMap::get() { return io_format_map_; }
} // namespace transform
} // namespace mindspore

+ 33
- 0
mindspore/ccsrc/transform/graph_ir/io_format_map.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_IO_FORMAT_MAP_H_
#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_IO_FORMAT_MAP_H_

#include <unordered_map>
#include <string>

namespace mindspore {
namespace transform {
class IOFormatMap {
public:
static std::unordered_map<std::string, std::string> &get();

private:
static std::unordered_map<std::string, std::string> io_format_map_;
};
} // namespace transform
} // namespace mindspore
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_IO_FORMAT_MAP_H_

+ 10
- 4
mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc View File

@@ -22,6 +22,7 @@

#include "utils/utils.h"
#include "transform/graph_ir/op_adapter_base.h"
#include "transform/graph_ir/io_format_map.h"

namespace mindspore {
namespace transform {
@@ -293,12 +294,17 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) {
MS_LOG(ERROR) << "The anf is not a Primitive.";
return ret;
}
ValuePtr format = prim->GetAttr("io_format");
if (format == nullptr) {
auto io_format_map = IOFormatMap::get();
auto iter = io_format_map.find(prim->name());
if (iter == io_format_map.end()) {
return "NCHW";
}
ret = GetValue<std::string>(format);
return ret;
if (iter->second == "format") {
ValuePtr format = prim->GetAttr("format");
MS_EXCEPTION_IF_NULL(format);
return GetValue<std::string>(format);
}
return iter->second;
}
} // namespace transform
} // namespace mindspore

+ 0
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -7490,7 +7490,6 @@ class Conv3D(PrimitiveWithInfer):
self.add_prim_attr('mode', self.mode)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.add_prim_attr('io_format', self.format)
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('groups', self.group)


Loading…
Cancel
Save