| @@ -65,21 +65,9 @@ void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &arg | |||
| } | |||
| } | |||
| } | |||
| bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_number, const TypeId &scalar_type, | |||
| const size_t &s_type_number) { | |||
| if (scalar_type == kNumberTypeFloat16 || scalar_type == kNumberTypeFloat32 || scalar_type == kNumberTypeFloat64) { | |||
| if (tensor_type == kNumberTypeFloat16 || tensor_type == kNumberTypeFloat32 || tensor_type == kNumberTypeFloat64) { | |||
| return t_type_number >= s_type_number; | |||
| } | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, | |||
| const size_t type_number) { | |||
| void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) { | |||
| *max_type_id = type_id; | |||
| *max_type = type; | |||
| *max_type_number = type_number; | |||
| } | |||
| @@ -118,7 +106,6 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId | |||
| TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices, | |||
| const std::set<size_t> &write_indices) { | |||
| TypeId max_type_id = kTypeUnknown; | |||
| TypeId max_type = kTypeUnknown; | |||
| size_t max_type_number = 0; | |||
| bool has_int8 = false; | |||
| for (const auto &index : indices) { | |||
| @@ -128,6 +115,9 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||
| if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { | |||
| continue; | |||
| } | |||
| if (arg_type != kObjectTypeTensorType) { | |||
| continue; | |||
| } | |||
| auto it = type_map.find(arg_type_id); | |||
| if (it == type_map.end()) { | |||
| continue; | |||
| @@ -136,24 +126,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||
| has_int8 = true; | |||
| } | |||
| if (max_type_id == kTypeUnknown) { | |||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); | |||
| continue; | |||
| } | |||
| if (max_type == arg_type) { | |||
| if (it->second > max_type_number) { | |||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| } | |||
| } else { | |||
| if (arg_type == kObjectTypeTensorType) { | |||
| if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) { | |||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| } | |||
| } else { | |||
| if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) { | |||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| } | |||
| } | |||
| if (it->second > max_type_number) { | |||
| SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); | |||
| } | |||
| } | |||