Browse Source

!13233 modify attr_target when load and export

From: @wangnan39
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
17654135aa
4 changed files with 54 additions and 2 deletions
  1. +1
    -1
      mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
  2. +1
    -1
      mindspore/core/load_mindir/anf_model_parser.cc
  3. +50
    -0
      mindspore/core/utils/check_convert_utils.cc
  4. +2
    -0
      mindspore/core/utils/check_convert_utils.h

+ 1
- 1
mindspore/ccsrc/transform/express_ir/mindir_exporter.cc View File

@@ -430,7 +430,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr.first); attr_proto->set_name(attr.first);
auto attr_value = attr.second; auto attr_value = attr.second;
CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value);
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
SetValueToAttributeProto(attr_value, attr_proto); SetValueToAttributeProto(attr_value, attr_proto);
} }
} else { } else {


+ 1
- 1
mindspore/core/load_mindir/anf_model_parser.cc View File

@@ -501,7 +501,7 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto); ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
const std::string &op_type = prim->name(); const std::string &op_type = prim->name();
if (!IsLite()) { if (!IsLite()) {
CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res);
CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res);
} }
prim->AddAttr(attr_name, res); prim->AddAttr(attr_name, res);
break; break;


+ 50
- 0
mindspore/core/utils/check_convert_utils.cc View File

@@ -253,6 +253,56 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type,
return true; return true;
} }


void ConvertTargetAttr(const std::string &attr_name, ValuePtr *const value) {
if (attr_name == "primitive_target") {
auto target_value = GetValue<std::string>(*value);
if (target_value == "CPU") {
*value = MakeValue<std::string>("host");
} else {
MS_LOG(EXCEPTION) << "The primitive_target only support CPU when export, but got " << target_value;
}
}
}

void RestoreTargetAttr(const std::string &attr_name, ValuePtr *const value) {
if (attr_name == "primitive_target") {
auto target_value = GetValue<std::string>(*value);
// compatible with exported model
if (target_value == "CPU") {
return;
}
if (target_value == "host") {
*value = MakeValue<std::string>("CPU");
} else {
MS_LOG(EXCEPTION) << "Invalid primitive_target value: " << target_value;
}
}
}

void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
ValuePtr *const value) {
if (value == nullptr || *value == nullptr) {
MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
return;
}
// convert enum to string
ConvertAttrValueToString(op_type, attr_name, value);
// set cpu target as host
ConvertTargetAttr(attr_name, value);
}

void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name,
ValuePtr *const value) {
if (value == nullptr || *value == nullptr) {
MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
return;
}
// convert string to enum
ConvertAttrValueToInt(op_type, attr_name, value);
// restore target as CPU
RestoreTargetAttr(attr_name, value);
}

namespace { namespace {
typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction; typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;




+ 2
- 0
mindspore/core/utils/check_convert_utils.h View File

@@ -284,6 +284,8 @@ class CheckAndConvertUtils {
const std::string &prim_name); const std::string &prim_name);
static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static void ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name); static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value); static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false); static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);


Loading…
Cancel
Save