|
|
|
@@ -181,15 +181,15 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args, |
|
|
|
} |
|
|
|
size_t priority = 0; |
|
|
|
TypeId max_type = TypeId::kTypeUnknown; |
|
|
|
bool has_float = false; |
|
|
|
bool has_int = false; |
|
|
|
bool has_int8 = false; |
|
|
|
bool has_scalar_float32 = false; |
|
|
|
bool has_scalar_int64 = false; |
|
|
|
bool has_tensor_int8 = false; |
|
|
|
for (size_t index : indexes) { |
|
|
|
if (!has_float && py::isinstance<py::float_>(py_args[index])) { |
|
|
|
has_float = true; |
|
|
|
if (!has_scalar_float32 && py::isinstance<py::float_>(py_args[index])) { |
|
|
|
has_scalar_float32 = true; |
|
|
|
} |
|
|
|
if (!has_int && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) { |
|
|
|
has_int = true; |
|
|
|
if (!has_scalar_int64 && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) { |
|
|
|
has_scalar_int64 = true; |
|
|
|
} |
|
|
|
|
|
|
|
auto obj = py_args[index]; |
|
|
|
@@ -201,7 +201,7 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args, |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (arg_type_id == kNumberTypeInt8) { |
|
|
|
has_int8 = true; |
|
|
|
has_tensor_int8 = true; |
|
|
|
} |
|
|
|
if (type_priority->second > priority) { |
|
|
|
max_type = type_priority->first; |
|
|
|
@@ -210,18 +210,18 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args, |
|
|
|
} |
|
|
|
} |
|
|
|
if (max_type == TypeId::kNumberTypeBool) { |
|
|
|
if (has_int) { |
|
|
|
if (has_scalar_int64) { |
|
|
|
max_type = TypeId::kNumberTypeInt64; |
|
|
|
} |
|
|
|
if (has_float) { |
|
|
|
if (has_scalar_float32) { |
|
|
|
max_type = TypeId::kNumberTypeFloat32; |
|
|
|
} |
|
|
|
} |
|
|
|
if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 && |
|
|
|
max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_float) { |
|
|
|
max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) { |
|
|
|
max_type = TypeId::kNumberTypeFloat32; |
|
|
|
} |
|
|
|
if (max_type == TypeId::kNumberTypeUInt8 && has_int8) { |
|
|
|
if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) { |
|
|
|
max_type = TypeId::kNumberTypeInt16; |
|
|
|
} |
|
|
|
(void)dst_type.emplace(std::make_pair(type, max_type)); |
|
|
|
|