|
|
@@ -4551,7 +4551,9 @@ class BroadcastTo(PrimitiveWithInfer): |
|
|
the target dimension is -1. In case of -1 in target shape, it will be replaced by the input shape's value |
|
|
the target dimension is -1. In case of -1 in target shape, it will be replaced by the input shape's value |
|
|
in that dimension. |
|
|
in that dimension. |
|
|
|
|
|
|
|
|
When input shape is broadcast to target shape, it starts with the trailing dimensions. |
|
|
|
|
|
|
|
|
When input shape is broadcast to target shape, it starts with the trailing |
|
|
|
|
|
dimensions. If there is a -1 in the target shape, the -1 cannot be in a leading, |
|
|
|
|
|
non-existing dimension. |
|
|
|
|
|
|
|
|
Args: |
|
|
Args: |
|
|
shape (tuple): The target shape to broadcast. Can be fully specified, or have -1 in one position |
|
|
shape (tuple): The target shape to broadcast. Can be fully specified, or have -1 in one position |
|
|
@@ -4566,9 +4568,8 @@ class BroadcastTo(PrimitiveWithInfer): |
|
|
|
|
|
|
|
|
Raises: |
|
|
Raises: |
|
|
TypeError: If `shape` is not a tuple. |
|
|
TypeError: If `shape` is not a tuple. |
|
|
ValueError: Given a shape tuple, if it has several -1; or if the -1 is in an invalid position |
|
|
|
|
|
such as one that does not have a opposing dimension in an input tensor; or if the target and |
|
|
|
|
|
input shapes are incompatible. |
|
|
|
|
|
|
|
|
ValueError: if the target and input shapes are incompatible, or if a -1 in the |
|
|
|
|
|
target shape is in an invalid location. |
|
|
|
|
|
|
|
|
Supported Platforms: |
|
|
Supported Platforms: |
|
|
``Ascend`` ``GPU`` |
|
|
``Ascend`` ``GPU`` |
|
|
@@ -4582,13 +4583,13 @@ class BroadcastTo(PrimitiveWithInfer): |
|
|
[[1. 2. 3.] |
|
|
[[1. 2. 3.] |
|
|
[1. 2. 3.]] |
|
|
[1. 2. 3.]] |
|
|
|
|
|
|
|
|
>>> shape = (2, -1) |
|
|
|
|
|
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) |
|
|
|
|
|
|
|
|
>>> shape = (-1, 2) |
|
|
|
|
|
>>> input_x = Tensor(np.array([[1], [2]]).astype(np.float32)) |
|
|
>>> broadcast_to = ops.BroadcastTo(shape) |
|
|
>>> broadcast_to = ops.BroadcastTo(shape) |
|
|
>>> output = broadcast_to(input_x) |
|
|
>>> output = broadcast_to(input_x) |
|
|
>>> print(output) |
|
|
>>> print(output) |
|
|
[[1. 2. 3.] |
|
|
|
|
|
[1. 2. 3.]] |
|
|
|
|
|
|
|
|
[[1. 1.] |
|
|
|
|
|
[2. 2.]] |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
@prim_attr_register |
|
|
@prim_attr_register |
|
|
@@ -4600,35 +4601,30 @@ class BroadcastTo(PrimitiveWithInfer): |
|
|
validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name) |
|
|
validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name) |
|
|
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name) |
|
|
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name) |
|
|
self.shape = shape |
|
|
self.shape = shape |
|
|
if -1 in self.shape: |
|
|
|
|
|
undef_dims = self.shape.count(-1) |
|
|
|
|
|
if undef_dims > 1: |
|
|
|
|
|
raise ValueError(f'The shape can only has one -1 at most, but has {undef_dims}.') |
|
|
|
|
|
self.dyn = True |
|
|
|
|
|
else: |
|
|
|
|
|
self.dyn = False |
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
def infer_shape(self, x_shape): |
|
|
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name) |
|
|
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name) |
|
|
target_shape = list(self.shape) |
|
|
|
|
|
outer_dim_offset = len(target_shape) - len(x_shape) |
|
|
|
|
|
if self.dyn: |
|
|
|
|
|
for i, v in enumerate(target_shape): |
|
|
|
|
|
if v == -1: |
|
|
|
|
|
if i < outer_dim_offset: |
|
|
|
|
|
raise ValueError(f" -1 in init shape is in an incompatible location" |
|
|
|
|
|
f" with given input tensor, -1 index in init shape: {i}" |
|
|
|
|
|
f" but -1 can only be in index {len(x_shape)} onwards for this input.") |
|
|
|
|
|
target_shape[i] = x_shape[i - outer_dim_offset] |
|
|
|
|
|
|
|
|
|
|
|
reversed_x_shape = tuple(reversed(x_shape)) |
|
|
reversed_x_shape = tuple(reversed(x_shape)) |
|
|
reversed_target = tuple(reversed(target_shape)) |
|
|
|
|
|
|
|
|
reversed_filtered_target = [] |
|
|
|
|
|
for i, v in enumerate(tuple(reversed(self.shape))): |
|
|
|
|
|
if v == -1: |
|
|
|
|
|
if i >= len(reversed_x_shape): |
|
|
|
|
|
raise ValueError("-1 is not valid in a leading, non-existing dimension") |
|
|
|
|
|
|
|
|
|
|
|
reversed_filtered_target.append(reversed_x_shape[i]) |
|
|
|
|
|
else: |
|
|
|
|
|
reversed_filtered_target.append(v) |
|
|
|
|
|
|
|
|
|
|
|
self.shape = tuple(reversed(reversed_filtered_target)) |
|
|
|
|
|
self.add_prim_attr('shape', self.shape) |
|
|
|
|
|
|
|
|
for i, v in enumerate(reversed_x_shape): |
|
|
for i, v in enumerate(reversed_x_shape): |
|
|
if v not in (reversed_target[i], 1): |
|
|
|
|
|
|
|
|
if v not in (reversed_filtered_target[i], 1): |
|
|
raise ValueError(f"Not supported shapes for broadcast, " |
|
|
raise ValueError(f"Not supported shapes for broadcast, " |
|
|
f"x_shape: {tuple(x_shape)}, target shape {target_shape}.") |
|
|
|
|
|
self.shape = tuple(target_shape) |
|
|
|
|
|
self.add_prim_attr('shape', self.shape) |
|
|
|
|
|
return target_shape |
|
|
|
|
|
|
|
|
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.") |
|
|
|
|
|
|
|
|
|
|
|
return self.shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
def infer_dtype(self, x_dtype): |
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name) |
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name) |
|
|
|