diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index b53a7412fc..1155fb7c03 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -230,8 +230,9 @@ def get_bprop_embedding_look_up(self): # Reshape the 'new_indices' new_indices_shape_changed = (size_op(new_indices),) new_indices = reshape_op(new_indices, new_indices_shape_changed) - x_shp_tail = x_shp[1:] - actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail + actual_dout_shape_changed = new_indices_shape_changed + if len(x_shp) > 1: + actual_dout_shape_changed += x_shp[1:] actual_dout = reshape_op(dout, actual_dout_shape_changed) return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) return bprop diff --git a/tests/st/ops/ascend/test_embedding_lookup.py b/tests/st/ops/ascend/test_embedding_lookup.py index 483fdcdbc4..6aee25d9da 100644 --- a/tests/st/ops/ascend/test_embedding_lookup.py +++ b/tests/st/ops/ascend/test_embedding_lookup.py @@ -15,6 +15,7 @@ import numpy as np import mindspore.context as context +import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore import Tensor from mindspore.ops import operations as P