Browse Source

!8693 Add Comments for UnsortedSegmentOps

From: @huangxinjing
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d549676d36
1 changed files with 15 additions and 0 deletions
  1. +15
    -0
      mindspore/ops/operations/array_ops.py

+ 15
- 0
mindspore/ops/operations/array_ops.py View File

@@ -653,6 +653,7 @@ class Transpose(PrimitiveWithCheck):
def check_dtype(self, x, perm):
validator.check_subclass("x", x, mstype.tensor, self.name)


class Unique(Primitive):
"""
Returns the unique elements of input tensor and also return a tensor containing the index of each value of input
@@ -1672,6 +1673,9 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
range.

Note:
If the segment_id i is absent in the segment_ids, then output[i] will be filled with 0.

If the sum of the given segment_ids :math:`i` is empty, then :math:`\text{output}[i] = 0`. If the given segment_ids
is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.

@@ -1751,6 +1755,10 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
The data type must be int32.
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.

Note:
If the segment_id i is absent in the segment_ids, then output[i] will be filled with
the maximum value of the input_x's type.

Outputs:
Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.

@@ -1801,6 +1809,10 @@ class UnsortedSegmentMax(PrimitiveWithInfer):
The data type must be int32.
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.

Note:
If the segment_id i is absent in the segment_ids, then output[i] will be filled with
the minimum value of the input_x's type.

Outputs:
Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.

@@ -2916,12 +2928,14 @@ class ScatterUpdate(_ScatterOp_Dynamic):
[[2.0, 1.2, 1.0],
[3.0, 1.2, 1.0]]
"""

@prim_attr_register
def __init__(self, use_locking=True):
"""Initialize ScatterUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])


class ScatterNdUpdate(_ScatterNdOp):
"""
Updates tensor value by using input indices and value.
@@ -3078,6 +3092,7 @@ class ScatterAdd(_ScatterOp_Dynamic):
>>> print(output)
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
"""

@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize ScatterAdd"""


Loading…
Cancel
Save