Browse Source

Fix shape checking for scatterop in dynamic shape scene

tags/v1.1.0
yujianfeng 5 years ago
parent
commit
552a3c4a07
2 changed files with 3 additions and 3 deletions
  1. +2
    -2
      mindspore/core/abstract/prim_arrays.cc
  2. +1
    -1
      mindspore/ops/operations/array_ops.py

+ 2
- 2
mindspore/core/abstract/prim_arrays.cc View File

@@ -169,8 +169,8 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
// Currently we choose the same data type as input for the idx.
TypePtr ids_idx_type = kInt32;
if (input->element() != nullptr && input->element()->GetTypeTrack() != nullptr) {
ids_idx_type = input->element()->GetTypeTrack();
if (input->element() != nullptr && input->element()->GetTypeTrack() == kInt64) {
ids_idx_type = kInt64;
}
auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, shape->shape());
// outputs: ids, ids_idx


+ 1
- 1
mindspore/ops/operations/array_ops.py View File

@@ -47,7 +47,7 @@ class _ScatterOp(PrimitiveWithInfer):
)

def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")


Loading…
Cancel
Save