|
|
|
@@ -5886,16 +5886,18 @@ class Range(PrimitiveWithCheck): |
|
|
|
return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype) |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
class MaskedSelect(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Returns a new 1-D Tensor which indexes the input tensor according to the boolean mask. |
|
|
|
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. |
|
|
|
- **mask** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. |
|
|
|
- **mask** (Tensor[bool]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`. |
|
|
|
A 1-D Tensor, with the same type as x. |
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: If `x` is not a Tensor. |
|
|
|
@@ -5905,7 +5907,7 @@ class MaskedSelect(PrimitiveWithCheck): |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64) |
|
|
|
>>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool) |
|
|
|
>>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_) |
|
|
|
>>> output = ops.MaskedSelect()(x, mask) |
|
|
|
>>> print(output) |
|
|
|
[1 3] |
|
|
|
@@ -5921,6 +5923,7 @@ class MaskedSelect(PrimitiveWithCheck): |
|
|
|
def check_dtype(self, x_dtype, mask_dtype): |
|
|
|
validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name) |
|
|
|
|
|
|
|
|
|
|
|
class SearchSorted(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Find the indices from the innermost dimension of `sequence` such that the order of the innermost dimension |
|
|
|
|