Browse Source

!7517 Choose the date type for idx in the infer implementation for Unique

Merge pull request !7517 from YuJianfeng/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f14593f5f3
1 changed files with 6 additions and 1 deletions
  1. +6
    -1
      mindspore/core/abstract/prim_arrays.cc

+ 6
- 1
mindspore/core/abstract/prim_arrays.cc View File

@@ -167,7 +167,12 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
ShapeVector max_shape = shape->shape();
auto ids =
std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
auto ids_idx = std::make_shared<AbstractTensor>(std::make_shared<Int>(32), shape->shape());
// Currently we choose the same data type as input for the idx.
TypePtr ids_idx_type = kInt32;
if (input->element() != nullptr && input->element()->GetTypeTrack() != nullptr) {
ids_idx_type = input->element()->GetTypeTrack();
}
auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, shape->shape());
// outputs: ids, ids_idx
AbstractBasePtrList elements = {ids, ids_idx};
return std::make_shared<AbstractTuple>(elements);


Loading…
Cancel
Save