|
|
|
@@ -253,6 +253,56 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertTargetAttr(const std::string &attr_name, ValuePtr *const value) { |
|
|
|
if (attr_name == "primitive_target") { |
|
|
|
auto target_value = GetValue<std::string>(*value); |
|
|
|
if (target_value == "CPU") { |
|
|
|
*value = MakeValue<std::string>("host"); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "The primitive_target only support CPU when export, but got " << target_value; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void RestoreTargetAttr(const std::string &attr_name, ValuePtr *const value) { |
|
|
|
if (attr_name == "primitive_target") { |
|
|
|
auto target_value = GetValue<std::string>(*value); |
|
|
|
// compatible with exported model |
|
|
|
if (target_value == "CPU") { |
|
|
|
return; |
|
|
|
} |
|
|
|
if (target_value == "host") { |
|
|
|
*value = MakeValue<std::string>("CPU"); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid primitive_target value: " << target_value; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, |
|
|
|
ValuePtr *const value) { |
|
|
|
if (value == nullptr || *value == nullptr) { |
|
|
|
MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name; |
|
|
|
return; |
|
|
|
} |
|
|
|
// convert enum to string |
|
|
|
ConvertAttrValueToString(op_type, attr_name, value); |
|
|
|
// set cpu target as host |
|
|
|
ConvertTargetAttr(attr_name, value); |
|
|
|
} |
|
|
|
|
|
|
|
void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, |
|
|
|
ValuePtr *const value) { |
|
|
|
if (value == nullptr || *value == nullptr) { |
|
|
|
MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name; |
|
|
|
return; |
|
|
|
} |
|
|
|
// convert string to enum |
|
|
|
ConvertAttrValueToInt(op_type, attr_name, value); |
|
|
|
// restore target as CPU |
|
|
|
RestoreTargetAttr(attr_name, value); |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction; |
|
|
|
|
|
|
|
|