@@ -807,7 +807,11 @@ class GatherV2(PrimitiveWithCheck):
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.int_], self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
class SparseGatherV2(GatherV2):
@@ -975,6 +979,12 @@ class Split(PrimitiveWithCheck):
x_shape = list(x['shape'])
dim = len(x_shape)
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
if -1 not in x_shape:
# only validate when shape fully known
output_valid_check = x_shape[self.axis] % self.output_num
if output_valid_check != 0:
raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
f" output_num {self.output_num}")
class Rank(PrimitiveWithInfer):
@@ -1945,18 +1955,21 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, x, segment_ids, num_segments):
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64],
self.name)
else:
validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
if (not -1 in x_shape and not -1 in segment_ids_shape):
# only validate when both shapes fully known
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_positive_int(num_segments_v, "num_segments", self.name)
class UnsortedSegmentMax(PrimitiveWithCheck):
@@ -1998,6 +2011,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, x, segment_ids, num_segments):
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
@@ -2005,12 +2019,14 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
[mstype.int32, mstype.int64], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64],
self.name)
else:
validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
if (not -1 in x_shape and not -1 in segment_ids_shape):
# only validate when both shapes fully known
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_positive_int(num_segments_v, "num_segments", self.name)
class UnsortedSegmentProd(PrimitiveWithInfer):