Browse Source

range check - UortedSegmin/max num_seg tensors

added in checks in PythonAPI to trigger Value errors in case of non dyn

added required check for python API

ST update + op check update

revert segSum change

updated gatherV2 and SplitOp check

added shape -1 sanity check - splitOp

slight refactor in unseg-Min/Max

lintfix

lint fix
tags/v1.1.0
danishnxt 5 years ago
parent
commit
bc84e9b4e7
2 changed files with 31 additions and 15 deletions
  1. +29
    -13
      mindspore/ops/operations/array_ops.py
  2. +2
    -2
      tests/st/ops/gpu/test_unsorted_segment_max.py

+ 29
- 13
mindspore/ops/operations/array_ops.py View File

@@ -807,7 +807,11 @@ class GatherV2(PrimitiveWithCheck):
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.int_], self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)


class SparseGatherV2(GatherV2):
@@ -975,6 +979,12 @@ class Split(PrimitiveWithCheck):
x_shape = list(x['shape'])
dim = len(x_shape)
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
if -1 not in x_shape:
# only validate when shape fully known
output_valid_check = x_shape[self.axis] % self.output_num
if output_valid_check != 0:
raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
f" output_num {self.output_num}")


class Rank(PrimitiveWithInfer):
@@ -1945,18 +1955,21 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, x, segment_ids, num_segments):
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64],
self.name)
else:
validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
if (not -1 in x_shape and not -1 in segment_ids_shape):
# only validate when both shapes fully known
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_positive_int(num_segments_v, "num_segments", self.name)


class UnsortedSegmentMax(PrimitiveWithCheck):
@@ -1998,6 +2011,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, x, segment_ids, num_segments):
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
@@ -2005,12 +2019,14 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
[mstype.int32, mstype.int64], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64],
self.name)
else:
validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
if (not -1 in x_shape and not -1 in segment_ids_shape):
# only validate when both shapes fully known
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_positive_int(num_segments_v, "num_segments", self.name)


class UnsortedSegmentProd(PrimitiveWithInfer):


+ 2
- 2
tests/st/ops/gpu/test_unsorted_segment_max.py View File

@@ -76,7 +76,7 @@ def test_3d_float16_int64():
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3), dtype=mindspore.float16)
segment_ids = Tensor([2, 1, 1, -1], mstype.int64)
num_segments = Tensor(5, dtype=mstype.int64)
num_segments = 5
net = UnsortedSegmentMaxNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[-6.55e+04, -6.55e+04, -6.55e+04],
@@ -115,7 +115,7 @@ def test_3d_float32_int64():
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([2, 1, 1, -1], mstype.int64)
num_segments = Tensor(3, dtype=mstype.int64)
num_segments = 3
net = UnsortedSegmentMaxNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],


Loading…
Cancel
Save