Merge pull request !6400 from zhangbuxue/add_overflow_check_for_make_range_and_optimize_isinstance_processingtags/v1.0.0
| @@ -173,7 +173,7 @@ def check_type_same(x_type, base_type): | |||||
| """Check x_type is same as base_type.""" | """Check x_type is same as base_type.""" | ||||
| if mstype.issubclass_(x_type, base_type): | if mstype.issubclass_(x_type, base_type): | ||||
| return True | return True | ||||
| raise TypeError(f"The arg 'x' should be a {base_type}, but got {x_type}.") | |||||
| return False | |||||
| @constexpr | @constexpr | ||||
| @@ -489,15 +489,25 @@ AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| if (slide.step <= 0) { | if (slide.step <= 0) { | ||||
| MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; | MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; | ||||
| } | } | ||||
| for (int i = slide.start; i < slide.stop; i += slide.step) { | for (int i = slide.start; i < slide.stop; i += slide.step) { | ||||
| args.push_back(abstract::FromValue(i)); | args.push_back(abstract::FromValue(i)); | ||||
| if (i > 0 && INT_MAX - i < slide.step) { | |||||
| MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, " | |||||
| "will cause integer overflow."; | |||||
| } | |||||
| } | } | ||||
| } else { | } else { | ||||
| if (slide.step >= 0) { | if (slide.step >= 0) { | ||||
| MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; | MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; | ||||
| } | } | ||||
| for (int i = slide.start; i > slide.stop; i += slide.step) { | for (int i = slide.start; i > slide.stop; i += slide.step) { | ||||
| args.push_back(abstract::FromValue(i)); | args.push_back(abstract::FromValue(i)); | ||||
| if (i < 0 && INT_MIN - i > slide.step) { | |||||
| MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, " | |||||
| "will cause integer overflow."; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -268,7 +268,7 @@ def _tensor_index_by_tuple_slice(data, t): | |||||
| def tensor_index_by_tuple(data, tuple_index): | def tensor_index_by_tuple(data, tuple_index): | ||||
| """Tensor getitem by tuple of various types""" | """Tensor getitem by tuple of various types""" | ||||
| if len(tuple_index) == 1: | if len(tuple_index) == 1: | ||||
| return data[tuple_index[0]] | |||||
| return data[tuple_index[0]] | |||||
| indexes_types = hyper_map(F.typeof, tuple_index) | indexes_types = hyper_map(F.typeof, tuple_index) | ||||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) | index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) | ||||
| if index_elements_type == const_utils.NO_TENSOR: | if index_elements_type == const_utils.NO_TENSOR: | ||||
| @@ -40,17 +40,17 @@ def test_number_not_in_tuple(): | |||||
| if self.number_in not in self.tuple_: | if self.number_in not in self.tuple_: | ||||
| ret += 1 | ret += 1 | ||||
| if self.number_not_in not in self.tuple_: | if self.number_not_in not in self.tuple_: | ||||
| ret += 1 | |||||
| ret += 2 | |||||
| if self.number_in not in self.list_: | if self.number_in not in self.list_: | ||||
| ret += 3 | ret += 3 | ||||
| if self.number_not_in not in self.list_: | if self.number_not_in not in self.list_: | ||||
| ret += 3 | |||||
| ret += 4 | |||||
| if self.str_in not in self.dict_: | if self.str_in not in self.dict_: | ||||
| ret += 5 | ret += 5 | ||||
| if self.str_not_in not in self.dict_: | if self.str_not_in not in self.dict_: | ||||
| ret += 5 | |||||
| ret += 6 | |||||
| return ret | return ret | ||||
| net = Net() | net = Net() | ||||
| output = net() | output = net() | ||||
| assert output == 9 | |||||
| assert output == 12 | |||||