diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 02f6ff0d11..4f1e027a31 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -807,7 +807,11 @@ class GatherV2(PrimitiveWithCheck): def __check__(self, params, indices, axis): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) - validator.check_subclass("axis", axis['dtype'], [mstype.int_], self.name) + validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) + axis_v = axis['value'] + validator.check_value_type('axis', axis_v, [int], self.name) + rank = len(params['shape']) + validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) class SparseGatherV2(GatherV2): @@ -975,6 +979,12 @@ class Split(PrimitiveWithCheck): x_shape = list(x['shape']) dim = len(x_shape) validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) + if -1 not in x_shape: + # only validate when shape fully known + output_valid_check = x_shape[self.axis] % self.output_num + if output_valid_check != 0: + raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by" + f" output_num {self.output_num}") class Rank(PrimitiveWithInfer): @@ -1945,18 +1955,21 @@ class UnsortedSegmentMin(PrimitiveWithCheck): self.add_prim_attr("dynamic_shape_depends", [2]) def __check__(self, x, segment_ids, num_segments): + x_shape = x['shape'] segment_ids_shape = segment_ids['shape'] valid_type = [mstype.float16, mstype.float32, mstype.int32] validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) num_segments_type = num_segments['dtype'] - validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) - if isinstance(num_segments_type, type(mstype.tensor)): - validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64], - self.name) - else: - validator.check_value_type('num_segments', num_segments['value'], [int], self.name) + validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name) + if (not -1 in x_shape and not -1 in segment_ids_shape): + # only validate when both shapes fully known + validator.check(f'first shape of input_x', x_shape[0], + 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) + num_segments_v = num_segments['value'] + validator.check_value_type('num_segments', num_segments_v, [int], self.name) + validator.check_positive_int(num_segments_v, "num_segments", self.name) class UnsortedSegmentMax(PrimitiveWithCheck): @@ -1998,6 +2011,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck): self.add_prim_attr("dynamic_shape_depends", [2]) def __check__(self, x, segment_ids, num_segments): + x_shape = x['shape'] segment_ids_shape = segment_ids['shape'] valid_type = [mstype.float16, mstype.float32, mstype.int32] validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) @@ -2005,12 +2019,14 @@ class UnsortedSegmentMax(PrimitiveWithCheck): [mstype.int32, mstype.int64], self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) num_segments_type = num_segments['dtype'] - validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) - if isinstance(num_segments_type, type(mstype.tensor)): - validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64], - self.name) - else: - validator.check_value_type('num_segments', num_segments['value'], [int], self.name) + validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name) + if (not -1 in x_shape and not -1 in segment_ids_shape): + # only validate when both shapes fully known + validator.check(f'first shape of input_x', x_shape[0], + 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) + num_segments_v = num_segments['value'] + validator.check_value_type('num_segments', num_segments_v, [int], self.name) + validator.check_positive_int(num_segments_v, "num_segments", self.name) class UnsortedSegmentProd(PrimitiveWithInfer): diff --git a/tests/st/ops/gpu/test_unsorted_segment_max.py b/tests/st/ops/gpu/test_unsorted_segment_max.py index fa1e1a32a8..b45b699e68 100644 --- a/tests/st/ops/gpu/test_unsorted_segment_max.py +++ b/tests/st/ops/gpu/test_unsorted_segment_max.py @@ -76,7 +76,7 @@ def test_3d_float16_int64(): input_x = Tensor(np.arange( 4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3), dtype=mindspore.float16) segment_ids = Tensor([2, 1, 1, -1], mstype.int64) - num_segments = Tensor(5, dtype=mstype.int64) + num_segments = 5 net = UnsortedSegmentMaxNet(num_segments) output = net(input_x, segment_ids).asnumpy() expect = np.array([[[-6.55e+04, -6.55e+04, -6.55e+04], @@ -115,7 +115,7 @@ def test_3d_float32_int64(): input_x = Tensor(np.arange( 4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32) segment_ids = Tensor([2, 1, 1, -1], mstype.int64) - num_segments = Tensor(3, dtype=mstype.int64) + num_segments = 3 net = UnsortedSegmentMaxNet(num_segments) output = net(input_x, segment_ids).asnumpy() expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],