|
|
|
@@ -1271,7 +1271,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. |
|
|
|
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is a prefix of `x_shape`. |
|
|
|
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. |
|
|
|
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
@@ -1279,7 +1279,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)) |
|
|
|
>>> segment_ids = Tensor(np.array([0, 1, 1]).np.int32) |
|
|
|
>>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32)) |
|
|
|
>>> num_segments = 2 |
|
|
|
>>> unsorted_segment_min = P.UnsortedSegmentMin() |
|
|
|
>>> unsorted_segment_min(input_x, segment_ids, num_segments) |
|
|
|
@@ -1299,6 +1299,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer): |
|
|
|
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) |
|
|
|
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) |
|
|
|
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) |
|
|
|
validator.check(f'first shape of input_x', x_shape[0], |
|
|
|
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) |
|
|
|
num_segments_v = num_segments['value'] |
|
|
|
validator.check_value_type('num_segments', num_segments_v, [int], self.name) |
|
|
|
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) |
|
|
|
|