|
|
|
@@ -167,7 +167,12 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
|
ShapeVector max_shape = shape->shape(); |
|
|
|
auto ids = |
|
|
|
std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape)); |
|
|
|
auto ids_idx = std::make_shared<AbstractTensor>(std::make_shared<Int>(32), shape->shape()); |
|
|
|
// Currently we choose the same data type as input for the idx. |
|
|
|
TypePtr ids_idx_type = kInt32; |
|
|
|
if (input->element() != nullptr && input->element()->GetTypeTrack() != nullptr) { |
|
|
|
ids_idx_type = input->element()->GetTypeTrack(); |
|
|
|
} |
|
|
|
auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, shape->shape()); |
|
|
|
// outputs: ids, ids_idx |
|
|
|
AbstractBasePtrList elements = {ids, ids_idx}; |
|
|
|
return std::make_shared<AbstractTuple>(elements); |
|
|
|
|