|
|
|
@@ -54,6 +54,8 @@ def get_broadcast_shape(x_shape, y_shape, prim_name): |
|
|
|
broadcast_shape_back.append(x_shape[i]) |
|
|
|
elif x_shape[i] == y_shape[i]: |
|
|
|
broadcast_shape_back.append(x_shape[i]) |
|
|
|
elif x_shape[i] == -1 or y_shape[i] == -1: |
|
|
|
broadcast_shape_back.append(-1) |
|
|
|
else: |
|
|
|
raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.") |
|
|
|
|
|
|
|
|