From 552a3c4a072991a69f9e93dc38d3dccd6d468a91 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Fri, 30 Oct 2020 14:39:11 +0800 Subject: [PATCH] Fix shape checking for scatterop in dynamic shape scene --- mindspore/core/abstract/prim_arrays.cc | 4 ++-- mindspore/ops/operations/array_ops.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index cb4e2753b5..abf3ede689 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -169,8 +169,8 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p std::make_shared(input->element(), std::make_shared(ids_shape, min_shape, max_shape)); // Currently we choose the same data type as input for the idx. TypePtr ids_idx_type = kInt32; - if (input->element() != nullptr && input->element()->GetTypeTrack() != nullptr) { - ids_idx_type = input->element()->GetTypeTrack(); + if (input->element() != nullptr && input->element()->GetTypeTrack() == kInt64) { + ids_idx_type = kInt64; } auto ids_idx = std::make_shared(ids_idx_type, shape->shape()); // outputs: ids, ids_idx diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 5138bbd074..d4648fcd23 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -47,7 +47,7 @@ class _ScatterOp(PrimitiveWithInfer): ) def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): - if updates_shape and updates_shape != indices_shape + x_shape[1:]: + if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]: raise ValueError(f"For '{prim_name}', " f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")