From 93b1243806ebec6235d2852638ab40c585dcd7a8 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Fri, 12 Mar 2021 15:44:10 +0800 Subject: [PATCH] modify attr_target when load and export --- .../transform/express_ir/mindir_exporter.cc | 2 +- .../core/load_mindir/anf_model_parser.cc | 2 +- mindspore/core/utils/check_convert_utils.cc | 50 +++++++++++++++++++ mindspore/core/utils/check_convert_utils.h | 2 + 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index 929bd98286..cd7a919590 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -428,7 +428,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name(attr.first); auto attr_value = attr.second; - CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value); + CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value); SetValueToAttributeProto(attr_value, attr_proto); } } else { diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index 7cd8133df2..f5b4f09b3e 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -501,7 +501,7 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto); const std::string &op_type = prim->name(); if (!IsLite()) { - CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res); + CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res); } prim->AddAttr(attr_name, res); break; diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index 3e05bf6f7d..fd7cfc27ff 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -254,6 +254,56 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, return true; } +void ConvertTargetAttr(const std::string &attr_name, ValuePtr *const value) { + if (attr_name == "primitive_target") { + auto target_value = GetValue(*value); + if (target_value == "CPU") { + *value = MakeValue("host"); + } else { + MS_LOG(EXCEPTION) << "The primitive_target only support CPU when export, but got " << target_value; + } + } +} + +void RestoreTargetAttr(const std::string &attr_name, ValuePtr *const value) { + if (attr_name == "primitive_target") { + auto target_value = GetValue(*value); + // compatible with exported model + if (target_value == "CPU") { + return; + } + if (target_value == "host") { + *value = MakeValue("CPU"); + } else { + MS_LOG(EXCEPTION) << "Invalid primitive_target value: " << target_value; + } + } +} + +void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, + ValuePtr *const value) { + if (value == nullptr || *value == nullptr) { + MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name; + return; + } + // convert enum to string + ConvertAttrValueToString(op_type, attr_name, value); + // set cpu target as host + ConvertTargetAttr(attr_name, value); +} + +void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, + ValuePtr *const value) { + if (value == nullptr || *value == nullptr) { + MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name; + return; + } + // convert string to enum + ConvertAttrValueToInt(op_type, attr_name, value); + // restore target as CPU + RestoreTargetAttr(attr_name, value); +} + namespace { typedef std::map> AttrFunction; diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index eeb3ab136e..c47c8b3779 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -284,6 +284,8 @@ class CheckAndConvertUtils { const std::string &prim_name); static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); + static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); + static void ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name); static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value); static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);