|
|
|
@@ -807,7 +807,11 @@ class GatherV2(PrimitiveWithCheck): |
|
|
|
def __check__(self, params, indices, axis): |
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) |
|
|
|
validator.check_subclass("axis", axis['dtype'], [mstype.int_], self.name) |
|
|
|
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) |
|
|
|
axis_v = axis['value'] |
|
|
|
validator.check_value_type('axis', axis_v, [int], self.name) |
|
|
|
rank = len(params['shape']) |
|
|
|
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) |
|
|
|
|
|
|
|
|
|
|
|
class SparseGatherV2(GatherV2): |
|
|
|
@@ -975,6 +979,12 @@ class Split(PrimitiveWithCheck): |
|
|
|
x_shape = list(x['shape']) |
|
|
|
dim = len(x_shape) |
|
|
|
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) |
|
|
|
if -1 not in x_shape: |
|
|
|
# only validate when shape fully known |
|
|
|
output_valid_check = x_shape[self.axis] % self.output_num |
|
|
|
if output_valid_check != 0: |
|
|
|
raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by" |
|
|
|
f" output_num {self.output_num}") |
|
|
|
|
|
|
|
|
|
|
|
class Rank(PrimitiveWithInfer): |
|
|
|
@@ -1945,18 +1955,21 @@ class UnsortedSegmentMin(PrimitiveWithCheck): |
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2]) |
|
|
|
|
|
|
|
def __check__(self, x, segment_ids, num_segments): |
|
|
|
x_shape = x['shape'] |
|
|
|
segment_ids_shape = segment_ids['shape'] |
|
|
|
valid_type = [mstype.float16, mstype.float32, mstype.int32] |
|
|
|
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) |
|
|
|
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) |
|
|
|
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) |
|
|
|
num_segments_type = num_segments['dtype'] |
|
|
|
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) |
|
|
|
if isinstance(num_segments_type, type(mstype.tensor)): |
|
|
|
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64], |
|
|
|
self.name) |
|
|
|
else: |
|
|
|
validator.check_value_type('num_segments', num_segments['value'], [int], self.name) |
|
|
|
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name) |
|
|
|
if (not -1 in x_shape and not -1 in segment_ids_shape): |
|
|
|
# only validate when both shapes fully known |
|
|
|
validator.check(f'first shape of input_x', x_shape[0], |
|
|
|
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) |
|
|
|
num_segments_v = num_segments['value'] |
|
|
|
validator.check_value_type('num_segments', num_segments_v, [int], self.name) |
|
|
|
validator.check_positive_int(num_segments_v, "num_segments", self.name) |
|
|
|
|
|
|
|
|
|
|
|
class UnsortedSegmentMax(PrimitiveWithCheck): |
|
|
|
@@ -1998,6 +2011,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck): |
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2]) |
|
|
|
|
|
|
|
def __check__(self, x, segment_ids, num_segments): |
|
|
|
x_shape = x['shape'] |
|
|
|
segment_ids_shape = segment_ids['shape'] |
|
|
|
valid_type = [mstype.float16, mstype.float32, mstype.int32] |
|
|
|
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) |
|
|
|
@@ -2005,12 +2019,14 @@ class UnsortedSegmentMax(PrimitiveWithCheck): |
|
|
|
[mstype.int32, mstype.int64], self.name) |
|
|
|
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) |
|
|
|
num_segments_type = num_segments['dtype'] |
|
|
|
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) |
|
|
|
if isinstance(num_segments_type, type(mstype.tensor)): |
|
|
|
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64], |
|
|
|
self.name) |
|
|
|
else: |
|
|
|
validator.check_value_type('num_segments', num_segments['value'], [int], self.name) |
|
|
|
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name) |
|
|
|
if (not -1 in x_shape and not -1 in segment_ids_shape): |
|
|
|
# only validate when both shapes fully known |
|
|
|
validator.check(f'first shape of input_x', x_shape[0], |
|
|
|
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) |
|
|
|
num_segments_v = num_segments['value'] |
|
|
|
validator.check_value_type('num_segments', num_segments_v, [int], self.name) |
|
|
|
validator.check_positive_int(num_segments_v, "num_segments", self.name) |
|
|
|
|
|
|
|
|
|
|
|
class UnsortedSegmentProd(PrimitiveWithInfer): |
|
|
|
|