Browse Source

!6400 add overflow check for make_range and optimize isinstance processing

Merge pull request !6400 from zhangbuxue/add_overflow_check_for_make_range_and_optimize_isinstance_processing
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8346da267b
4 changed files with 16 additions and 6 deletions
  1. +1
    -1
      mindspore/_extends/parse/standard_method.py
  2. +10
    -0
      mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc
  3. +1
    -1
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  4. +4
    -4
      tests/ut/python/pipeline/infer/test_not_in.py

+ 1
- 1
mindspore/_extends/parse/standard_method.py View File

@@ -173,7 +173,7 @@ def check_type_same(x_type, base_type):
"""Check x_type is same as base_type.""" """Check x_type is same as base_type."""
if mstype.issubclass_(x_type, base_type): if mstype.issubclass_(x_type, base_type):
return True return True
raise TypeError(f"The arg 'x' should be a {base_type}, but got {x_type}.")
return False




@constexpr @constexpr


+ 10
- 0
mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc View File

@@ -489,15 +489,25 @@ AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr
if (slide.step <= 0) { if (slide.step <= 0) {
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
} }

for (int i = slide.start; i < slide.stop; i += slide.step) { for (int i = slide.start; i < slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i)); args.push_back(abstract::FromValue(i));
if (i > 0 && INT_MAX - i < slide.step) {
MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, "
"will cause integer overflow.";
}
} }
} else { } else {
if (slide.step >= 0) { if (slide.step >= 0) {
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
} }

for (int i = slide.start; i > slide.stop; i += slide.step) { for (int i = slide.start; i > slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i)); args.push_back(abstract::FromValue(i));
if (i < 0 && INT_MIN - i > slide.step) {
MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, "
"will cause integer overflow.";
}
} }
} }




+ 1
- 1
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -268,7 +268,7 @@ def _tensor_index_by_tuple_slice(data, t):
def tensor_index_by_tuple(data, tuple_index): def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types""" """Tensor getitem by tuple of various types"""
if len(tuple_index) == 1: if len(tuple_index) == 1:
return data[tuple_index[0]]
return data[tuple_index[0]]
indexes_types = hyper_map(F.typeof, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR: if index_elements_type == const_utils.NO_TENSOR:


+ 4
- 4
tests/ut/python/pipeline/infer/test_not_in.py View File

@@ -40,17 +40,17 @@ def test_number_not_in_tuple():
if self.number_in not in self.tuple_: if self.number_in not in self.tuple_:
ret += 1 ret += 1
if self.number_not_in not in self.tuple_: if self.number_not_in not in self.tuple_:
ret += 1
ret += 2
if self.number_in not in self.list_: if self.number_in not in self.list_:
ret += 3 ret += 3
if self.number_not_in not in self.list_: if self.number_not_in not in self.list_:
ret += 3
ret += 4
if self.str_in not in self.dict_: if self.str_in not in self.dict_:
ret += 5 ret += 5
if self.str_not_in not in self.dict_: if self.str_not_in not in self.dict_:
ret += 5
ret += 6
return ret return ret


net = Net() net = Net()
output = net() output = net()
assert output == 9
assert output == 12

Loading…
Cancel
Save