|
|
|
@@ -2716,6 +2716,8 @@ class BatchToSpaceND(PrimitiveWithInfer): |
|
|
|
class BroadcastTo(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Broadcasts input tensor to a given shape. |
|
|
|
Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one. |
|
|
|
When input shape is broadcast to target shape, it starts with the trailing dimensions. |
|
|
|
|
|
|
|
Args: |
|
|
|
shape (tuple): The target shape to broadcast. |
|
|
|
@@ -2738,11 +2740,20 @@ class BroadcastTo(PrimitiveWithInfer): |
|
|
|
def __init__(self, shape): |
|
|
|
"""Init BroadcastTo""" |
|
|
|
validator.check_value_type("shape", shape, (tuple), self.name) |
|
|
|
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name) |
|
|
|
for i in shape: |
|
|
|
validator.check_integer("shape element", i, 0, Rel.GT, self.name) |
|
|
|
self.shape = shape |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name) |
|
|
|
|
|
|
|
reversed_x_shape = tuple(reversed(x_shape)) |
|
|
|
reversed_target = tuple(reversed(self.shape)) |
|
|
|
for i, v in enumerate(reversed_x_shape): |
|
|
|
if v not in (reversed_target[i], 1): |
|
|
|
raise ValueError(f"Not supported shapes for broadcast, " |
|
|
|
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.") |
|
|
|
return self.shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
|