| @@ -14,8 +14,9 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Check parameters.""" | """Check parameters.""" | ||||
| import re | import re | ||||
| import inspect | |||||
| from enum import Enum | from enum import Enum | ||||
| from functools import reduce | |||||
| from functools import reduce, wraps | |||||
| from itertools import repeat | from itertools import repeat | ||||
| from collections.abc import Iterable | from collections.abc import Iterable | ||||
| @@ -181,7 +182,7 @@ class Validator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_subclass(arg_name, type_, template_type, prim_name): | def check_subclass(arg_name, type_, template_type, prim_name): | ||||
| """Checks whether some type is sublcass of another type""" | |||||
| """Checks whether some type is subclass of another type""" | |||||
| if not isinstance(template_type, Iterable): | if not isinstance(template_type, Iterable): | ||||
| template_type = (template_type,) | template_type = (template_type,) | ||||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | if not any([mstype.issubclass_(type_, x) for x in template_type]): | ||||
| @@ -240,7 +241,6 @@ class Validator: | |||||
| elem_types = map(_check_tensor_type, args.items()) | elem_types = map(_check_tensor_type, args.items()) | ||||
| reduce(_check_types_same, elem_types) | reduce(_check_types_same, elem_types) | ||||
| @staticmethod | @staticmethod | ||||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | ||||
| """ | """ | ||||
| @@ -261,7 +261,7 @@ class Validator: | |||||
| def _check_types_same(arg1, arg2): | def _check_types_same(arg1, arg2): | ||||
| arg1_name, arg1_type = arg1 | arg1_name, arg1_type = arg1 | ||||
| arg2_name, arg2_type = arg2 | arg2_name, arg2_type = arg2 | ||||
| excp_flag = False | |||||
| except_flag = False | |||||
| if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): | if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): | ||||
| arg1_type = arg1_type.element_type() | arg1_type = arg1_type.element_type() | ||||
| arg2_type = arg2_type.element_type() | arg2_type = arg2_type.element_type() | ||||
| @@ -271,9 +271,9 @@ class Validator: | |||||
| arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type | arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type | ||||
| arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type | arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type | ||||
| else: | else: | ||||
| excp_flag = True | |||||
| except_flag = True | |||||
| if excp_flag or arg1_type != arg2_type: | |||||
| if except_flag or arg1_type != arg2_type: | |||||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | ||||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | ||||
| return arg1 | return arg1 | ||||
| @@ -283,11 +283,12 @@ class Validator: | |||||
| def check_value_type(arg_name, arg_value, valid_types, prim_name): | def check_value_type(arg_name, arg_value, valid_types, prim_name): | ||||
| """Checks whether a value is instance of some types.""" | """Checks whether a value is instance of some types.""" | ||||
| valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | ||||
| def raise_error_msg(): | def raise_error_msg(): | ||||
| """func for raising error message when check failed""" | """func for raising error message when check failed""" | ||||
| type_names = [t.__name__ for t in valid_types] | type_names = [t.__name__ for t in valid_types] | ||||
| num_types = len(valid_types) | num_types = len(valid_types) | ||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | ||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | ||||
| @@ -303,6 +304,7 @@ class Validator: | |||||
| def check_type_name(arg_name, arg_type, valid_types, prim_name): | def check_type_name(arg_name, arg_type, valid_types, prim_name): | ||||
| """Checks whether a type in some specified types""" | """Checks whether a type in some specified types""" | ||||
| valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | ||||
| def get_typename(t): | def get_typename(t): | ||||
| return t.__name__ if hasattr(t, '__name__') else str(t) | return t.__name__ if hasattr(t, '__name__') else str(t) | ||||
| @@ -368,9 +370,9 @@ class ParamValidator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_isinstance(arg_name, arg_value, classes): | def check_isinstance(arg_name, arg_value, classes): | ||||
| """Check arg isintance of classes""" | |||||
| """Check arg isinstance of classes""" | |||||
| if not isinstance(arg_value, classes): | if not isinstance(arg_value, classes): | ||||
| raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.') | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | |||||
| return arg_value | return arg_value | ||||
| @staticmethod | @staticmethod | ||||
| @@ -384,7 +386,7 @@ class ParamValidator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_subclass(arg_name, type_, template_type, with_type_of=True): | def check_subclass(arg_name, type_, template_type, with_type_of=True): | ||||
| """Check whether some type is sublcass of another type""" | |||||
| """Check whether some type is subclass of another type""" | |||||
| if not isinstance(template_type, Iterable): | if not isinstance(template_type, Iterable): | ||||
| template_type = (template_type,) | template_type = (template_type,) | ||||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | if not any([mstype.issubclass_(type_, x) for x in template_type]): | ||||
| @@ -402,9 +404,9 @@ class ParamValidator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_bool(arg_name, arg_value): | def check_bool(arg_name, arg_value): | ||||
| """Check arg isintance of bool""" | |||||
| """Check arg isinstance of bool""" | |||||
| if not isinstance(arg_value, bool): | if not isinstance(arg_value, bool): | ||||
| raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.') | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') | |||||
| return arg_value | return arg_value | ||||
| @staticmethod | @staticmethod | ||||
| @@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII): | |||||
| if re.match(reg, target, flag) is None: | if re.match(reg, target, flag) is None: | ||||
| raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) | raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) | ||||
| return True | return True | ||||
| def args_type_check(*type_args, **type_kwargs): | |||||
| """Check whether input data type is correct.""" | |||||
| def type_check(func): | |||||
| sig = inspect.signature(func) | |||||
| bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments | |||||
| @wraps(func) | |||||
| def wrapper(*args, **kwargs): | |||||
| nonlocal bound_types | |||||
| bound_values = sig.bind(*args, **kwargs) | |||||
| argument_dict = bound_values.arguments | |||||
| if "kwargs" in bound_types: | |||||
| bound_types = bound_types["kwargs"] | |||||
| if "kwargs" in argument_dict: | |||||
| argument_dict = argument_dict["kwargs"] | |||||
| for name, value in argument_dict.items(): | |||||
| if name in bound_types: | |||||
| if value is not None and not isinstance(value, bound_types[name]): | |||||
| raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) | |||||
| return func(*args, **kwargs) | |||||
| return wrapper | |||||
| return type_check | |||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """ | """ | ||||
| Extension functions. | |||||
| Extension functions. | |||||
| Python functions that will be called in the c++ parts of MindSpore. | Python functions that will be called in the c++ parts of MindSpore. | ||||
| """ | """ | ||||
| @@ -1,44 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Pynative mode help module.""" | |||||
| from inspect import signature | |||||
| from functools import wraps | |||||
| def args_type_check(*type_args, **type_kwargs): | |||||
| """Check whether input data type is correct.""" | |||||
| def type_check(func): | |||||
| sig = signature(func) | |||||
| bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments | |||||
| @wraps(func) | |||||
| def wrapper(*args, **kwargs): | |||||
| nonlocal bound_types | |||||
| bound_values = sig.bind(*args, **kwargs) | |||||
| argument_dict = bound_values.arguments | |||||
| if "kwargs" in bound_types: | |||||
| bound_types = bound_types["kwargs"] | |||||
| if "kwargs" in argument_dict: | |||||
| argument_dict = argument_dict["kwargs"] | |||||
| for name, value in argument_dict.items(): | |||||
| if name in bound_types: | |||||
| if value is not None and not isinstance(value, bound_types[name]): | |||||
| raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) | |||||
| return func(*args, **kwargs) | |||||
| return wrapper | |||||
| return type_check | |||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ | """ | ||||
| The context of mindspore, used to configure the current execution environment, | The context of mindspore, used to configure the current execution environment, | ||||
| including execution mode, execution backend and other feature switchs. | |||||
| including execution mode, execution backend and other feature switches. | |||||
| """ | """ | ||||
| import os | import os | ||||
| import threading | import threading | ||||
| @@ -22,7 +22,7 @@ from collections import namedtuple | |||||
| from types import FunctionType | from types import FunctionType | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore._c_expression import MSContext | from mindspore._c_expression import MSContext | ||||
| from mindspore._extends.pynative_helper import args_type_check | |||||
| from mindspore._checkparam import args_type_check | |||||
| from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ | from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ | ||||
| _reset_auto_parallel_context | _reset_auto_parallel_context | ||||
| @@ -38,7 +38,7 @@ def _make_directory(path: str): | |||||
| """Make directory.""" | """Make directory.""" | ||||
| real_path = None | real_path = None | ||||
| if path is None or not isinstance(path, str) or path.strip() == "": | if path is None or not isinstance(path, str) or path.strip() == "": | ||||
| raise ValueError(f"Input path `{path}` is invaild type") | |||||
| raise ValueError(f"Input path `{path}` is invalid type") | |||||
| # convert the relative paths | # convert the relative paths | ||||
| path = os.path.realpath(path) | path = os.path.realpath(path) | ||||
| @@ -63,6 +63,7 @@ class _ThreadLocalInfo(threading.local): | |||||
| """ | """ | ||||
| Thread local Info used for store thread local attributes. | Thread local Info used for store thread local attributes. | ||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(_ThreadLocalInfo, self).__init__() | super(_ThreadLocalInfo, self).__init__() | ||||
| self._reserve_class_name_in_scope = True | self._reserve_class_name_in_scope = True | ||||
| @@ -90,6 +91,7 @@ class _ContextSwitchInfo(threading.local): | |||||
| Args: | Args: | ||||
| is_pynative (bool): Whether to adopt the PyNative mode. | is_pynative (bool): Whether to adopt the PyNative mode. | ||||
| """ | """ | ||||
| def __init__(self, is_pynative): | def __init__(self, is_pynative): | ||||
| super(_ContextSwitchInfo, self).__init__() | super(_ContextSwitchInfo, self).__init__() | ||||
| self.context_stack = [] | self.context_stack = [] | ||||
| @@ -209,7 +211,7 @@ class _Context: | |||||
| def device_target(self, target): | def device_target(self, target): | ||||
| success = self._context_handle.set_device_target(target) | success = self._context_handle.set_device_target(target) | ||||
| if not success: | if not success: | ||||
| raise ValueError("target device name is invalid!!!") | |||||
| raise ValueError("Target device name is invalid!!!") | |||||
| @property | @property | ||||
| def device_id(self): | def device_id(self): | ||||
| @@ -335,7 +337,7 @@ class _Context: | |||||
| @graph_memory_max_size.setter | @graph_memory_max_size.setter | ||||
| def graph_memory_max_size(self, graph_memory_max_size): | def graph_memory_max_size(self, graph_memory_max_size): | ||||
| if check_input_fotmat(graph_memory_max_size): | |||||
| if check_input_format(graph_memory_max_size): | |||||
| graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024" | graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024" | ||||
| self._context_handle.set_graph_memory_max_size(graph_memory_max_size_) | self._context_handle.set_graph_memory_max_size(graph_memory_max_size_) | ||||
| else: | else: | ||||
| @@ -347,7 +349,7 @@ class _Context: | |||||
| @variable_memory_max_size.setter | @variable_memory_max_size.setter | ||||
| def variable_memory_max_size(self, variable_memory_max_size): | def variable_memory_max_size(self, variable_memory_max_size): | ||||
| if check_input_fotmat(variable_memory_max_size): | |||||
| if check_input_format(variable_memory_max_size): | |||||
| variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" | variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" | ||||
| self._context_handle.set_variable_memory_max_size(variable_memory_max_size_) | self._context_handle.set_variable_memory_max_size(variable_memory_max_size_) | ||||
| else: | else: | ||||
| @@ -367,12 +369,13 @@ class _Context: | |||||
| thread_info.debug_runtime = enable | thread_info.debug_runtime = enable | ||||
| def check_input_fotmat(x): | |||||
| def check_input_format(x): | |||||
| import re | import re | ||||
| pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' | pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' | ||||
| result = re.match(pattern, x) | result = re.match(pattern, x) | ||||
| return result is not None | return result is not None | ||||
| _k_context = None | _k_context = None | ||||
| @@ -17,7 +17,7 @@ import threading | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size | from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size | ||||
| from mindspore._c_expression import AutoParallelContext | from mindspore._c_expression import AutoParallelContext | ||||
| from mindspore._extends.pynative_helper import args_type_check | |||||
| from mindspore._checkparam import args_type_check | |||||
| class _AutoParallelContext: | class _AutoParallelContext: | ||||
| @@ -15,7 +15,7 @@ | |||||
| """Context of cost_model in auto_parallel""" | """Context of cost_model in auto_parallel""" | ||||
| import threading | import threading | ||||
| from mindspore._c_expression import CostModelContext | from mindspore._c_expression import CostModelContext | ||||
| from mindspore._extends.pynative_helper import args_type_check | |||||
| from mindspore._checkparam import args_type_check | |||||
| class _CostModelContext: | class _CostModelContext: | ||||
| @@ -16,7 +16,7 @@ | |||||
| import threading | import threading | ||||
| from mindspore._c_expression import CostModelContext | from mindspore._c_expression import CostModelContext | ||||
| from mindspore._extends.pynative_helper import args_type_check | |||||
| from mindspore._checkparam import args_type_check | |||||
| __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] | __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] | ||||
| @@ -14,16 +14,13 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test_backend """ | """ test_backend """ | ||||
| import os | import os | ||||
| import numpy as np | |||||
| import pytest | import pytest | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | |||||
| from mindspore import context, ms_function | |||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore._extends.pynative_helper import args_type_check | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore._checkparam import args_type_check | |||||
| def setup_module(module): | def setup_module(module): | ||||
| @@ -32,6 +29,7 @@ def setup_module(module): | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| """ Net definition """ | """ Net definition """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| self.add = P.TensorAdd() | self.add = P.TensorAdd() | ||||
| @@ -50,6 +48,7 @@ def test_vm_backend(): | |||||
| output = add() | output = add() | ||||
| assert output.asnumpy().shape == (1, 3, 3, 4) | assert output.asnumpy().shape == (1, 3, 3, 4) | ||||
| def test_vm_set_context(): | def test_vm_set_context(): | ||||
| """ test_vm_set_context """ | """ test_vm_set_context """ | ||||
| context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE) | context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE) | ||||
| @@ -59,6 +58,7 @@ def test_vm_set_context(): | |||||
| assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 | assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 | ||||
| context.set_context(mode=context.PYNATIVE_MODE) | context.set_context(mode=context.PYNATIVE_MODE) | ||||
| @args_type_check(v_str=str, v_int=int, v_tuple=tuple) | @args_type_check(v_str=str, v_int=int, v_tuple=tuple) | ||||
| def check_input(v_str, v_int, v_tuple): | def check_input(v_str, v_int, v_tuple): | ||||
| """ check_input """ | """ check_input """ | ||||