|
|
|
@@ -331,7 +331,7 @@ std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const |
|
|
|
std::ostringstream buffer; |
|
|
|
buffer << "The primitive[" << prim_name << "]'s attribute[" << arg_name << "]"; |
|
|
|
if (check_list.size() == 1) { |
|
|
|
buffer << "must be \"" << (*check_list.begin()) << "\",but got \"" << arg_value << "\"."; |
|
|
|
buffer << " must be \"" << (*check_list.begin()) << "\",but got \"" << arg_value << "\"."; |
|
|
|
MS_EXCEPTION(ValueError) << buffer.str(); |
|
|
|
} |
|
|
|
buffer << " should be a element of {"; |
|
|
|
@@ -363,7 +363,7 @@ int64_t CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int64_t |
|
|
|
if (iter_to_string == kCompareToString.end()) { |
|
|
|
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map"; |
|
|
|
} |
|
|
|
buffer << iter_to_string->second << "\'" << match_value << "\' , but got \'" << arg_value << "\'."; |
|
|
|
buffer << iter_to_string->second << match_value << ", but got " << arg_value << "."; |
|
|
|
MS_EXCEPTION(ValueError) << buffer.str(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -442,13 +442,13 @@ void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, Comp |
|
|
|
if (prim_name.empty()) { |
|
|
|
buffer << "The attribute[" << arg_name << "] must "; |
|
|
|
} else { |
|
|
|
buffer << "For primitive[" << prim_name << "]'s attribute[" << arg_name << "] must "; |
|
|
|
buffer << "The primitive[" << prim_name << "]'s attribute[" << arg_name << "] must "; |
|
|
|
} |
|
|
|
auto iter_to_string = kCompareToString.find(compare_type); |
|
|
|
if (iter_to_string == kCompareToString.end()) { |
|
|
|
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map"; |
|
|
|
} |
|
|
|
buffer << iter_to_string->second << "\'" << value << "\', but got \'" << arg_value << "\'"; |
|
|
|
buffer << iter_to_string->second << value << ", but got " << arg_value << "."; |
|
|
|
MS_EXCEPTION(ValueError) << buffer.str(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -462,36 +462,41 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty |
|
|
|
MS_EXCEPTION_IF_NULL(type); |
|
|
|
if (!type->isa<TensorType>()) { |
|
|
|
std::ostringstream buffer; |
|
|
|
buffer << "The Primitive[" << prim_name << "]'s input_arguments must be all tensor.\n"; |
|
|
|
for (const auto &type_info : types) { |
|
|
|
buffer << "input_arguments[" << type_info.first << "]" |
|
|
|
<< ":" << type_info.second->ToString() << "\n"; |
|
|
|
} |
|
|
|
buffer << "The primitive[" << prim_name << "]'s input arguments must be all tensor.\n"; |
|
|
|
if (!check_list.empty()) { |
|
|
|
buffer << "Valid type: {"; |
|
|
|
buffer << "Valid type list: {"; |
|
|
|
for (auto const &valid_type : check_list) { |
|
|
|
buffer << valid_type->ToString() << ", "; |
|
|
|
if (valid_type->isa<TensorType>()) { |
|
|
|
buffer << valid_type->ToString() << ", "; |
|
|
|
break; |
|
|
|
} |
|
|
|
buffer << "Tensor[" << valid_type << "]" |
|
|
|
<< ", "; |
|
|
|
} |
|
|
|
buffer << "}.\n"; |
|
|
|
} |
|
|
|
for (const auto &type_info : types) { |
|
|
|
buffer << "input argument[" << type_info.first << "]" |
|
|
|
<< ":" << type_info.second->ToString() << "\n"; |
|
|
|
} |
|
|
|
MS_EXCEPTION(TypeError) << buffer.str(); |
|
|
|
} |
|
|
|
} |
|
|
|
auto check_type = _CheckTypeSame(types, prim_name, check_list, false); |
|
|
|
auto check_type = _CheckTypeSame(types, prim_name, false); |
|
|
|
std::string input_names = ""; |
|
|
|
for (const auto &item : types) { |
|
|
|
(void)input_names.append(item.first); |
|
|
|
(void)input_names.append(", "); |
|
|
|
} |
|
|
|
return CheckTypeValid(input_names, check_type, check_list, prim_name); |
|
|
|
return CheckSubClass(input_names, check_type, check_list, prim_name); |
|
|
|
} |
|
|
|
|
|
|
|
TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type, |
|
|
|
const std::set<TypePtr> &check_list, const std::string &prim_name) { |
|
|
|
MS_EXCEPTION_IF_NULL(type); |
|
|
|
if (!type->isa<TensorType>()) { |
|
|
|
MS_EXCEPTION(TypeError) << "The Primitive[" << prim_name << "] input_arguments[" << type_name |
|
|
|
<< "] must be a \'Tensor\' but got \'" << type->ToString() << "\'."; |
|
|
|
MS_EXCEPTION(TypeError) << "The Primitive[" << prim_name << "] input argument[" << type_name |
|
|
|
<< "] must be a Tensor but got " << type->ToString() << "."; |
|
|
|
} |
|
|
|
auto tensor_type = type->cast<TensorTypePtr>(); |
|
|
|
auto element = tensor_type->element(); |
|
|
|
@@ -504,17 +509,19 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return CheckSubClass(type_name, element, check_list, prim_name); |
|
|
|
return CheckSubClass(type_name, type, check_list, prim_name); |
|
|
|
} |
|
|
|
|
|
|
|
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value, |
|
|
|
const std::string &prim_name) { |
|
|
|
if (value == nullptr) { |
|
|
|
MS_EXCEPTION(ValueError) << "The " << prim_name << "'s " << type_name << " value is nullptr."; |
|
|
|
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name |
|
|
|
<< "] value is nullptr."; |
|
|
|
} |
|
|
|
ShapeVector tensor_value; |
|
|
|
if (!value->isa<tensor::Tensor>()) { |
|
|
|
MS_EXCEPTION(ValueError) << "The " << prim_name << "'s " << type_name << " must be a tensor."; |
|
|
|
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name |
|
|
|
<< "] must be a tensor,but got " << value->ToString(); |
|
|
|
} |
|
|
|
auto input_tensor = value->cast<tensor::TensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(input_tensor); |
|
|
|
@@ -532,26 +539,35 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_data); |
|
|
|
tensor_value = {tensor_data, tensor_data + data_size}; |
|
|
|
} else { |
|
|
|
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s " << type_name << " must be a int32 or int64."; |
|
|
|
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "] input argument[" << type_name |
|
|
|
<< "] must be a Tensor[Int64] or Tensor[Int32] type,but got " << value->ToString(); |
|
|
|
} |
|
|
|
return tensor_value; |
|
|
|
} |
|
|
|
|
|
|
|
TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr &type_, |
|
|
|
TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr &type, |
|
|
|
const std::set<TypePtr> &template_types, const std::string &prim_name) { |
|
|
|
bool ok = std::any_of(template_types.begin(), template_types.end(), |
|
|
|
[type_](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(type_, accept); }); |
|
|
|
auto check_type = type; |
|
|
|
bool ok = std::any_of(template_types.begin(), template_types.end(), [check_type](const TypePtr &accept) -> bool { |
|
|
|
return IsIdentidityOrSubclass(check_type, accept); |
|
|
|
}); |
|
|
|
if (ok) { |
|
|
|
return check_type; |
|
|
|
} |
|
|
|
if (type->isa<TensorType>()) { |
|
|
|
auto tensor_type = type->cast<TensorTypePtr>(); |
|
|
|
check_type = tensor_type->element(); |
|
|
|
} |
|
|
|
ok = std::any_of(template_types.begin(), template_types.end(), |
|
|
|
[check_type](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(check_type, accept); }); |
|
|
|
if (ok) { |
|
|
|
return type_; |
|
|
|
return check_type; |
|
|
|
} else { |
|
|
|
std::string type_str = type_->ToString(); |
|
|
|
std::string type_str = type->ToString(); |
|
|
|
std::ostringstream buffer; |
|
|
|
buffer << "Primitive[" << prim_name << "]'s arguments[" << type_name << "]'s type:" << type_str |
|
|
|
<< " must be a subclass of ["; |
|
|
|
for (const auto &template_type : template_types) { |
|
|
|
buffer << template_type->ToString() << ", "; |
|
|
|
} |
|
|
|
buffer << "]"; |
|
|
|
buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of "; |
|
|
|
buffer << GetErrorTypeString(template_types, type) << ", but got " << type->ToString(); |
|
|
|
buffer << "."; |
|
|
|
MS_EXCEPTION(TypeError) << buffer.str(); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -559,21 +575,19 @@ TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const |
|
|
|
TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, |
|
|
|
const std::set<TypePtr> &valid_values, |
|
|
|
const std::string &prim_name, const bool allow_mix) { |
|
|
|
auto arg_ = _CheckTypeSame(args, prim_name, valid_values, allow_mix); |
|
|
|
auto arg_ = _CheckTypeSame(args, prim_name, allow_mix); |
|
|
|
return CheckTypeValid(args.begin()->first, arg_, valid_values, prim_name); |
|
|
|
} |
|
|
|
|
|
|
|
TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name, |
|
|
|
const std::set<TypePtr> &check_list, const bool allow_mix) { |
|
|
|
const bool allow_mix) { |
|
|
|
if (args.empty()) { |
|
|
|
MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!"; |
|
|
|
} |
|
|
|
std::ostringstream buffer; |
|
|
|
TypePtr return_type = nullptr; |
|
|
|
TypePtr return_type = args.begin()->second; |
|
|
|
buffer << "The primitive[" << prim_name << "]"; |
|
|
|
auto first_type = args.begin()->second; |
|
|
|
MS_EXCEPTION_IF_NULL(first_type); |
|
|
|
bool tensor_flag = first_type->isa<TensorType>(); |
|
|
|
bool tensor_flag = return_type->isa<TensorType>(); |
|
|
|
std::set<TypeId> types_id; |
|
|
|
for (const auto &elem : args) { |
|
|
|
auto type = elem.second; |
|
|
|
@@ -584,11 +598,7 @@ TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr |
|
|
|
buffer << "'s " |
|
|
|
<< "input type must be same.\n"; |
|
|
|
for (const auto &error_elem : args) { |
|
|
|
buffer << "input_arguments[" << error_elem.first << "]:" << error_elem.second->ToString() << "\n"; |
|
|
|
} |
|
|
|
buffer << "Validate type list:["; |
|
|
|
for (const auto &item : check_list) { |
|
|
|
buffer << item->ToString() << ", "; |
|
|
|
buffer << "input argument[" << error_elem.first << "]:" << error_elem.second->ToString() << "\n"; |
|
|
|
} |
|
|
|
MS_EXCEPTION(TypeError) << buffer.str(); |
|
|
|
} |
|
|
|
@@ -597,26 +607,24 @@ TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr |
|
|
|
auto tensor_type = type->cast<TensorTypePtr>(); |
|
|
|
auto element = tensor_type->element(); |
|
|
|
MS_EXCEPTION_IF_NULL(element); |
|
|
|
return_type = element->DeepCopy(); |
|
|
|
if (!allow_mix) { |
|
|
|
return_type = element; |
|
|
|
} else { |
|
|
|
return_type = tensor_type; |
|
|
|
} |
|
|
|
(void)types_id.emplace(element->type_id()); |
|
|
|
} else { |
|
|
|
(void)types_id.emplace(type->type_id()); |
|
|
|
return_type = type->DeepCopy(); |
|
|
|
} |
|
|
|
if (types_id.size() > 1) { |
|
|
|
buffer << "'s input type must be same.\n"; |
|
|
|
for (const auto &item : args) { |
|
|
|
buffer << "name:[" << item.first << "]:" << item.second->ToString() << ".\n"; |
|
|
|
} |
|
|
|
buffer << "Validate type list:["; |
|
|
|
for (const auto &item : check_list) { |
|
|
|
buffer << item->ToString() << ", "; |
|
|
|
} |
|
|
|
buffer << "]."; |
|
|
|
MS_EXCEPTION(TypeError) << buffer.str(); |
|
|
|
} |
|
|
|
} |
|
|
|
return return_type; |
|
|
|
return return_type->DeepCopy(); |
|
|
|
} |
|
|
|
|
|
|
|
TypePtr CheckAndConvertUtils::CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type, |
|
|
|
@@ -661,7 +669,7 @@ void CheckAndConvertUtils::CheckSummaryParam(const AbstractBasePtr &name, const |
|
|
|
(void)CheckTypeValid("name", name->BuildType(), {kString}, class_name); |
|
|
|
auto s = GetValue<std::string>(name->BuildValue()); |
|
|
|
if (s.empty()) { |
|
|
|
MS_EXCEPTION(ValueError) << "The primitive[" << class_name << "]'s input_arguments[name] " |
|
|
|
MS_EXCEPTION(ValueError) << "The primitive[" << class_name << "]'s input argument[name] " |
|
|
|
<< " cannot be an empty string."; |
|
|
|
} |
|
|
|
(void)CheckTypeValid("value", value->BuildType(), {kTensorType}, class_name); |
|
|
|
@@ -721,7 +729,7 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string & |
|
|
|
}); |
|
|
|
} else { |
|
|
|
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name |
|
|
|
<< "] must be a tuple with all Int elements, but got \'" << attr->ToString() << "\'"; |
|
|
|
<< "] must be a tuple with all Int elements, but got " << attr->ToString() << "."; |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|
@@ -765,4 +773,43 @@ bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_l |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
std::string CheckAndConvertUtils::GetErrorTypeString(const std::set<TypePtr> &check_list, const TypePtr &check_type) { |
|
|
|
std::ostringstream buffer; |
|
|
|
buffer << "{"; |
|
|
|
// got tensor type list |
|
|
|
for (const auto &item : check_list) { |
|
|
|
if (item->isa<TensorType>()) { |
|
|
|
buffer << item->ToString(); |
|
|
|
buffer << ", "; |
|
|
|
continue; |
|
|
|
} |
|
|
|
buffer << "Tensor[" << item->ToString() << "], "; |
|
|
|
} |
|
|
|
if (check_type->isa<TensorType>()) { |
|
|
|
buffer << "}"; |
|
|
|
return buffer.str(); |
|
|
|
} |
|
|
|
// got python type |
|
|
|
std::set<std::string> type_string; |
|
|
|
for (const auto &item : check_list) { |
|
|
|
if (item->isa<Float>()) { |
|
|
|
type_string.emplace("Float"); |
|
|
|
} |
|
|
|
if (item->isa<Int>()) { |
|
|
|
type_string.emplace("Int"); |
|
|
|
} |
|
|
|
if (item->isa<Bool>()) { |
|
|
|
type_string.emplace("Bool"); |
|
|
|
} |
|
|
|
if (item->isa<UInt>()) { |
|
|
|
type_string.emplace("UInt"); |
|
|
|
} |
|
|
|
} |
|
|
|
for (const auto &item : type_string) { |
|
|
|
buffer << item << ","; |
|
|
|
} |
|
|
|
buffer << "}"; |
|
|
|
return buffer.str(); |
|
|
|
} |
|
|
|
} // namespace mindspore |