Browse Source

fix scatter error msg

tags/v0.7.0-beta
fangzehua 5 years ago
parent
commit
17d3982d46
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      mindspore/ops/operations/array_ops.py

+ 2
- 2
mindspore/ops/operations/array_ops.py View File

@@ -50,7 +50,7 @@ class _ScatterOp(PrimitiveWithInfer):


def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
if updates_shape and updates_shape != indices_shape + x_shape[1:]: if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
raise ValueError(f"For '{prim_name}', "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")


@@ -79,7 +79,7 @@ class _ScatterNdOp(_ScatterOp):
validator.check('the dimension of x', len(x_shape), validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE) 'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape: if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape:
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or updates_shape = "
raise ValueError(f"For '{prim_name}', updates_shape = "
f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, " f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")




Loading…
Cancel
Save