diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py index 712c79af38..e20fef9746 100644 --- a/mindspore/ops/_utils/utils.py +++ b/mindspore/ops/_utils/utils.py @@ -54,6 +54,8 @@ def get_broadcast_shape(x_shape, y_shape, prim_name): broadcast_shape_back.append(x_shape[i]) elif x_shape[i] == y_shape[i]: broadcast_shape_back.append(x_shape[i]) + elif x_shape[i] == -1 or y_shape[i] == -1: + broadcast_shape_back.append(-1) else: raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.")