|
|
|
@@ -170,7 +170,9 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
|
std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape)); |
|
|
|
// Currently we choose the same data type as input for the idx. |
|
|
|
TypePtr ids_idx_type = kInt32; |
|
|
|
if (input->element() != nullptr && input->element()->GetTypeTrack() == kInt64) { |
|
|
|
MS_EXCEPTION_IF_NULL(input->element()); |
|
|
|
MS_EXCEPTION_IF_NULL(input->element()->GetTypeTrack()); |
|
|
|
if (input->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { |
|
|
|
ids_idx_type = kInt64; |
|
|
|
} |
|
|
|
auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, shape->shape()); |
|
|
|
|