| @@ -65,21 +65,9 @@ void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &arg | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_number, const TypeId &scalar_type, | |||||
| const size_t &s_type_number) { | |||||
| if (scalar_type == kNumberTypeFloat16 || scalar_type == kNumberTypeFloat32 || scalar_type == kNumberTypeFloat64) { | |||||
| if (tensor_type == kNumberTypeFloat16 || tensor_type == kNumberTypeFloat32 || tensor_type == kNumberTypeFloat64) { | |||||
| return t_type_number >= s_type_number; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, | |||||
| const size_t type_number) { | |||||
| void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) { | |||||
| *max_type_id = type_id; | *max_type_id = type_id; | ||||
| *max_type = type; | |||||
| *max_type_number = type_number; | *max_type_number = type_number; | ||||
| } | } | ||||
| @@ -118,7 +106,6 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId | |||||
| TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices, | TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices, | ||||
| const std::set<size_t> &write_indices) { | const std::set<size_t> &write_indices) { | ||||
| TypeId max_type_id = kTypeUnknown; | TypeId max_type_id = kTypeUnknown; | ||||
| TypeId max_type = kTypeUnknown; | |||||
| size_t max_type_number = 0; | size_t max_type_number = 0; | ||||
| bool has_int8 = false; | bool has_int8 = false; | ||||
| for (const auto &index : indices) { | for (const auto &index : indices) { | ||||
| @@ -128,6 +115,9 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||||
| if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { | if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (arg_type != kObjectTypeTensorType) { | |||||
| continue; | |||||
| } | |||||
| auto it = type_map.find(arg_type_id); | auto it = type_map.find(arg_type_id); | ||||
| if (it == type_map.end()) { | if (it == type_map.end()) { | ||||
| continue; | continue; | ||||
| @@ -136,24 +126,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||||
| has_int8 = true; | has_int8 = true; | ||||
| } | } | ||||
| if (max_type_id == kTypeUnknown) { | if (max_type_id == kTypeUnknown) { | ||||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (max_type == arg_type) { | |||||
| if (it->second > max_type_number) { | |||||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| } | |||||
| } else { | |||||
| if (arg_type == kObjectTypeTensorType) { | |||||
| if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) { | |||||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| } | |||||
| } else { | |||||
| if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) { | |||||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| } | |||||
| } | |||||
| if (it->second > max_type_number) { | |||||
| SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); | |||||
| } | } | ||||
| } | } | ||||