| @@ -50,7 +50,7 @@ class _ScatterOp(PrimitiveWithInfer): | |||||
| def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | ||||
| if updates_shape and updates_shape != indices_shape + x_shape[1:]: | if updates_shape and updates_shape != indices_shape + x_shape[1:]: | ||||
| raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " | |||||
| raise ValueError(f"For '{prim_name}', " | |||||
| f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | ||||
| f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | ||||
| @@ -79,7 +79,7 @@ class _ScatterNdOp(_ScatterOp): | |||||
| validator.check('the dimension of x', len(x_shape), | validator.check('the dimension of x', len(x_shape), | ||||
| 'the dimension of indices', indices_shape[-1], Rel.GE) | 'the dimension of indices', indices_shape[-1], Rel.GE) | ||||
| if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape: | if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape: | ||||
| raise ValueError(f"For '{prim_name}', the shape of updates should be [] or updates_shape = " | |||||
| raise ValueError(f"For '{prim_name}', updates_shape = " | |||||
| f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, " | f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, " | ||||
| f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | ||||