|
|
|
@@ -215,15 +215,21 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign |
|
|
|
TypeId arg_type_id = kTypeUnknown; |
|
|
|
AbstractBasePtr arg_value = args_spec_list[i]; |
|
|
|
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); |
|
|
|
auto it_map = type_map.find(arg_type_id); |
|
|
|
if (it_map == type_map.end()) { |
|
|
|
auto it_map = type_name_map.find(arg_type_id); |
|
|
|
if (it_map == type_name_map.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (is_write) { |
|
|
|
if (arg_type_id != it->second) { |
|
|
|
MS_LOG(EXCEPTION) << "In op '" << func_name << "', argument '" << args_spec_list[i] |
|
|
|
<< "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '" |
|
|
|
<< TypeIdLabel(it->second) << "' automatically."; |
|
|
|
auto it_name_map = type_name_map.find(it->second); |
|
|
|
if (it_name_map == type_name_map.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n" |
|
|
|
<< "the type of writable argument is '" << it_map->second << "', " |
|
|
|
<< "but the largest type in the same SignatureEumDtype is '" << it_name_map->second |
|
|
|
<< "'. The writable arg type is not equal to the largest type, " |
|
|
|
<< "so can not cast automatically."; |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
|