Browse Source

fix file clean

tags/v1.3.0
wuxuejian 4 years ago
parent
commit
5ac46e0007
2 changed files with 7 additions and 3 deletions
  1. +1
    -0
      mindspore/ops/_grad/grad_array_ops.py
  2. +6
    -3
      mindspore/ops/operations/array_ops.py

+ 1
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -1060,6 +1060,7 @@ def get_bprop_unique(self):

return bprop


@bprop_getters.register(P.MaskedSelect)
def get_bprop_masked_select(self):
"""Generate bprop for MaskedSelect"""


+ 6
- 3
mindspore/ops/operations/array_ops.py View File

@@ -5886,16 +5886,18 @@ class Range(PrimitiveWithCheck):
return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype)
return None


class MaskedSelect(PrimitiveWithCheck):
"""
Returns a new 1-D Tensor which indexes the input tensor according to the boolean mask.
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.

Inputs:
- **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **mask** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **mask** (Tensor[bool]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

Outputs:
Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`.
A 1-D Tensor, with the same type as x.

Raises:
TypeError: If `x` is not a Tensor.
@@ -5905,7 +5907,7 @@ class MaskedSelect(PrimitiveWithCheck):

Examples:
>>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
>>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool)
>>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
>>> output = ops.MaskedSelect()(x, mask)
>>> print(output)
[1 3]
@@ -5921,6 +5923,7 @@ class MaskedSelect(PrimitiveWithCheck):
def check_dtype(self, x_dtype, mask_dtype):
validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)


class SearchSorted(PrimitiveWithInfer):
"""
Find the indices from the innermost dimension of `sequence` such that the order of the innermost dimension


Loading…
Cancel
Save