|
|
|
@@ -1883,18 +1883,27 @@ class UnsortedSegmentSum(PrimitiveWithInfer): |
|
|
|
shp = [num_segments_v] |
|
|
|
|
|
|
|
shp += x_shp[segment_ids_shp_len:] |
|
|
|
if 'max_shape' in x: |
|
|
|
output_incoming = x['max_shape'] |
|
|
|
if "max_value" in num_segments and "min_value" in num_segments: |
|
|
|
output_max_shape = list(num_segments['max_value']) |
|
|
|
output_min_shape = list(num_segments['min_value']) |
|
|
|
else: |
|
|
|
if isinstance(num_segments_type, type(mstype.tensor)): |
|
|
|
raise ValueError("In dynamic shape scene, the num_segments should contains max_value and min_value") |
|
|
|
output_max_shape = [num_segments_v] |
|
|
|
output_max_shape += output_incoming[segment_ids_shp_len:] |
|
|
|
output_min_shape = [num_segments_v] |
|
|
|
if 'max_shape' in x and 'min_shape' in x: |
|
|
|
max_output_incoming = x['max_shape'] |
|
|
|
min_output_incoming = x['min_shape'] |
|
|
|
else: |
|
|
|
output_max_shape = x_shp |
|
|
|
out = {'shape': shp, |
|
|
|
'max_shape': output_max_shape, |
|
|
|
'min_shape': [1] * segment_ids_shp_len + x_shp[segment_ids_shp_len:], |
|
|
|
'dtype': mstype.tensor_type(x_type.element_type()), |
|
|
|
'value': None} |
|
|
|
return out |
|
|
|
max_output_incoming = x_shp |
|
|
|
min_output_incoming = x_shp |
|
|
|
output_max_shape += max_output_incoming[segment_ids_shp_len:] |
|
|
|
output_min_shape += min_output_incoming[segment_ids_shp_len:] |
|
|
|
return {'shape': shp, |
|
|
|
'max_shape': output_max_shape, |
|
|
|
'min_shape': output_min_shape, |
|
|
|
'dtype': mstype.tensor_type(x_type.element_type()), |
|
|
|
'value': None} |
|
|
|
|
|
|
|
|
|
|
|
class UnsortedSegmentMin(PrimitiveWithCheck): |
|
|
|
@@ -2688,6 +2697,26 @@ class StridedSlice(PrimitiveWithInfer): |
|
|
|
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v) |
|
|
|
|
|
|
|
value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type()) |
|
|
|
if "max_value" in x and "min_value" in x: |
|
|
|
validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name) |
|
|
|
validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name) |
|
|
|
max_value_np = np.array(x["max_value"]) |
|
|
|
min_value_np = np.array(x["min_value"]) |
|
|
|
slice_index = [] |
|
|
|
for begin_i, end_i, strides_i in zip(begin_v, end_v, strides_v): |
|
|
|
s = slice(begin_i, end_i, strides_i) |
|
|
|
slice_index.append(s) |
|
|
|
slice_index = tuple(slice_index) |
|
|
|
max_value_slice = max_value_np[slice_index] |
|
|
|
min_value_slice = min_value_np[slice_index] |
|
|
|
max_value_slice = tuple(max_value_slice.tolist()) |
|
|
|
min_value_slice = tuple(min_value_slice.tolist()) |
|
|
|
return {'shape': ret_shape, |
|
|
|
'dtype': x['dtype'], |
|
|
|
'value': value, |
|
|
|
'max_value': max_value_slice, |
|
|
|
'min_value': min_value_slice} |
|
|
|
|
|
|
|
return {'shape': ret_shape, |
|
|
|
'dtype': x['dtype'], |
|
|
|
'value': value} |
|
|
|
|