|
|
|
@@ -146,51 +146,55 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC |
|
|
|
|
|
|
|
bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) { |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
if (value->isa<StringImm>()) { |
|
|
|
auto attr_value_str = GetValue<std::string>(value); |
|
|
|
if (DataFormatToEnumMap.find(attr_value_str) == DataFormatToEnumMap.end()) { |
|
|
|
MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
*enum_value = DataFormatToEnumMap[attr_value_str]; |
|
|
|
return true; |
|
|
|
} else { |
|
|
|
if (!value->isa<StringImm>()) { |
|
|
|
*enum_value = GetValue<int64_t>(value); |
|
|
|
return true; |
|
|
|
} |
|
|
|
auto attr_value_str = GetValue<std::string>(value); |
|
|
|
auto iter = DataFormatToEnumMap.find(attr_value_str); |
|
|
|
if (iter == DataFormatToEnumMap.end()) { |
|
|
|
MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
*enum_value = iter->second; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) { |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
if (value->isa<StringImm>()) { |
|
|
|
auto attr_value_str = GetValue<std::string>(value); |
|
|
|
if (!value->isa<StringImm>()) { |
|
|
|
*enum_value = GetValue<int64_t>(value); |
|
|
|
return; |
|
|
|
} |
|
|
|
auto attr_value_str = GetValue<std::string>(value); |
|
|
|
|
|
|
|
std::map<std::string, int64_t> pad_map = PadModToEnumMap; |
|
|
|
if (is_upper) { |
|
|
|
pad_map = PadModToEnumUpperMap; |
|
|
|
} |
|
|
|
if (pad_map.find(attr_value_str) == pad_map.end()) { |
|
|
|
if (is_upper) { |
|
|
|
auto iter = PadModToEnumUpperMap.find(attr_value_str); |
|
|
|
if (iter == PadModToEnumUpperMap.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same"; |
|
|
|
} |
|
|
|
*enum_value = pad_map[attr_value_str]; |
|
|
|
} else { |
|
|
|
*enum_value = GetValue<int64_t>(value); |
|
|
|
*enum_value = iter->second; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto iter = PadModToEnumMap.find(attr_value_str); |
|
|
|
if (iter == PadModToEnumMap.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same"; |
|
|
|
} |
|
|
|
*enum_value = iter->second; |
|
|
|
} |
|
|
|
|
|
|
|
void CheckAndConvertUtils::GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value) { |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
if (value->isa<StringImm>()) { |
|
|
|
auto attr_value_str = GetValue<std::string>(value); |
|
|
|
|
|
|
|
std::map<std::string, int64_t> pad_map = ReductionToEnumMap; |
|
|
|
if (pad_map.find(attr_value_str) == pad_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same"; |
|
|
|
} |
|
|
|
*enum_value = pad_map[attr_value_str]; |
|
|
|
} else { |
|
|
|
if (!value->isa<StringImm>()) { |
|
|
|
*enum_value = GetValue<int64_t>(value); |
|
|
|
return; |
|
|
|
} |
|
|
|
auto attr_value_str = GetValue<std::string>(value); |
|
|
|
auto iter = ReductionToEnumMap.find(attr_value_str); |
|
|
|
if (iter == ReductionToEnumMap.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same"; |
|
|
|
} |
|
|
|
*enum_value = iter->second; |
|
|
|
} |
|
|
|
|
|
|
|
AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) { |
|
|
|
|