|
|
@@ -50,7 +50,7 @@ class _ScatterOp(PrimitiveWithInfer): |
|
|
|
|
|
|
|
|
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): |
|
|
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): |
|
|
if updates_shape and updates_shape != indices_shape + x_shape[1:]: |
|
|
if updates_shape and updates_shape != indices_shape + x_shape[1:]: |
|
|
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " |
|
|
|
|
|
|
|
|
raise ValueError(f"For '{prim_name}', " |
|
|
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " |
|
|
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " |
|
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") |
|
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") |
|
|
|
|
|
|
|
|
@@ -79,7 +79,7 @@ class _ScatterNdOp(_ScatterOp): |
|
|
validator.check('the dimension of x', len(x_shape), |
|
|
validator.check('the dimension of x', len(x_shape), |
|
|
'the dimension of indices', indices_shape[-1], Rel.GE) |
|
|
'the dimension of indices', indices_shape[-1], Rel.GE) |
|
|
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape: |
|
|
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape: |
|
|
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or updates_shape = " |
|
|
|
|
|
|
|
|
raise ValueError(f"For '{prim_name}', updates_shape = " |
|
|
f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, " |
|
|
f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, " |
|
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") |
|
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") |
|
|
|
|
|
|
|
|
|