From: @wangnan39 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -430,7 +430,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons | |||||
| mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); | mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); | ||||
| attr_proto->set_name(attr.first); | attr_proto->set_name(attr.first); | ||||
| auto attr_value = attr.second; | auto attr_value = attr.second; | ||||
| CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value); | |||||
| CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value); | |||||
| SetValueToAttributeProto(attr_value, attr_proto); | SetValueToAttributeProto(attr_value, attr_proto); | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -501,7 +501,7 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind | |||||
| ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto); | ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto); | ||||
| const std::string &op_type = prim->name(); | const std::string &op_type = prim->name(); | ||||
| if (!IsLite()) { | if (!IsLite()) { | ||||
| CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res); | |||||
| CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res); | |||||
| } | } | ||||
| prim->AddAttr(attr_name, res); | prim->AddAttr(attr_name, res); | ||||
| break; | break; | ||||
| @@ -253,6 +253,56 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, | |||||
| return true; | return true; | ||||
| } | } | ||||
| void ConvertTargetAttr(const std::string &attr_name, ValuePtr *const value) { | |||||
| if (attr_name == "primitive_target") { | |||||
| auto target_value = GetValue<std::string>(*value); | |||||
| if (target_value == "CPU") { | |||||
| *value = MakeValue<std::string>("host"); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "The primitive_target only support CPU when export, but got " << target_value; | |||||
| } | |||||
| } | |||||
| } | |||||
| void RestoreTargetAttr(const std::string &attr_name, ValuePtr *const value) { | |||||
| if (attr_name == "primitive_target") { | |||||
| auto target_value = GetValue<std::string>(*value); | |||||
| // compatible with exported model | |||||
| if (target_value == "CPU") { | |||||
| return; | |||||
| } | |||||
| if (target_value == "host") { | |||||
| *value = MakeValue<std::string>("CPU"); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Invalid primitive_target value: " << target_value; | |||||
| } | |||||
| } | |||||
| } | |||||
| void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, | |||||
| ValuePtr *const value) { | |||||
| if (value == nullptr || *value == nullptr) { | |||||
| MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name; | |||||
| return; | |||||
| } | |||||
| // convert enum to string | |||||
| ConvertAttrValueToString(op_type, attr_name, value); | |||||
| // set cpu target as host | |||||
| ConvertTargetAttr(attr_name, value); | |||||
| } | |||||
| void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, | |||||
| ValuePtr *const value) { | |||||
| if (value == nullptr || *value == nullptr) { | |||||
| MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name; | |||||
| return; | |||||
| } | |||||
| // convert string to enum | |||||
| ConvertAttrValueToInt(op_type, attr_name, value); | |||||
| // restore target as CPU | |||||
| RestoreTargetAttr(attr_name, value); | |||||
| } | |||||
| namespace { | namespace { | ||||
| typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction; | typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction; | ||||
| @@ -284,6 +284,8 @@ class CheckAndConvertUtils { | |||||
| const std::string &prim_name); | const std::string &prim_name); | ||||
| static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | ||||
| static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | ||||
| static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | |||||
| static void ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); | |||||
| static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name); | static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name); | ||||
| static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value); | static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value); | ||||
| static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false); | static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false); | ||||