|
|
|
@@ -106,6 +106,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve |
|
|
|
TypeId max_type_id = kTypeUnknown; |
|
|
|
size_t max_type_number = 0; |
|
|
|
bool has_int8 = false; |
|
|
|
bool has_scalar_int32 = false; |
|
|
|
bool has_scalar_float32 = false; |
|
|
|
for (const auto &index : indices) { |
|
|
|
TypeId arg_type_id = kTypeUnknown; |
|
|
|
TypeId arg_type = kTypeUnknown; |
|
|
|
@@ -114,6 +116,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (arg_type != kObjectTypeTensorType) { |
|
|
|
if (arg_type_id == kNumberTypeInt32) { |
|
|
|
has_scalar_int32 = true; |
|
|
|
} else if (arg_type_id == kNumberTypeFloat32) { |
|
|
|
has_scalar_float32 = true; |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto it = type_map.find(arg_type_id); |
|
|
|
@@ -135,6 +142,17 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve |
|
|
|
if (max_type_id == kNumberTypeUInt8 && has_int8 == true) { |
|
|
|
max_type_id = kNumberTypeInt16; |
|
|
|
} |
|
|
|
// if bool is the max type, see if there is scalar input |
|
|
|
// if so, it means that max is bool tensor, use scalar type instead. |
|
|
|
// for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2]) |
|
|
|
if (max_type_id == kNumberTypeBool) { |
|
|
|
if (has_scalar_int32) { |
|
|
|
max_type_id = kNumberTypeInt32; |
|
|
|
} |
|
|
|
if (has_scalar_float32) { |
|
|
|
max_type_id = kNumberTypeFloat32; |
|
|
|
} |
|
|
|
} |
|
|
|
return max_type_id; |
|
|
|
} |
|
|
|
|
|
|
|
|