|
- # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- #
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Providing interface methods."""
- import types
- from collections import OrderedDict
- from functools import wraps
- from mindspore import context
- from mindspore import log as logger
- from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
- from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
- from .tensor import Tensor as MsTensor
- from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
- from ..parallel._ps_context import _is_role_pserver
- # store ms_function class compiled pipeline cache
- ms_compile_cache = {}
-
-
- def _convert_function_arguments(fn, *args):
- """
- Process the fn default parameters.
-
- Args:
- fn (Function): The function to be parsed.
- args (tuple): The parameters of the function.
- """
- arguments_dict = OrderedDict()
- parse_method = None
- if isinstance(fn, (types.FunctionType, types.MethodType)):
- parse_method = fn.__name__
- index = 0
- for value in args:
- arguments_dict[f'arg{index}'] = value
- index = index + 1
- logger.debug("fn(%r) full parameters dict is: %r", fn, arguments_dict)
- converted = True
- else:
- logger.warning("Find error: fn isn't function or method")
- converted = False
- return converted, arguments_dict, parse_method
-
-
- def _wrap_func(fn):
- """
- Wrapper function, convert return data to tensor or tuple of tensor.
-
- Args:
- fn (Function): The function need be wrapped.
-
- Returns:
- Function, a new function with return suitable format data.
- """
-
- @wraps(fn)
- def wrapper(*arg, **kwargs):
- results = fn(*arg, **kwargs)
-
- def _convert_data(data):
- if isinstance(data, Tensor) and not isinstance(data, MsTensor):
- return MsTensor(data)
- if isinstance(data, tuple):
- return tuple(_convert_data(x) for x in data)
- if isinstance(data, list):
- return list(_convert_data(x) for x in data)
- return data
-
- return _convert_data(results)
-
- return wrapper
-
-
- def _exec_init_graph(obj, init_phase):
- """Execute the parameter initializer graph."""
- inst_executor = Executor_.get_instance()
- param_dict = OrderedDict()
- for name, param in obj.parameters_dict().items():
- if not param.is_init:
- param_dict[name] = param
- param.is_init = True
- param.data.init_flag = True
-
- if param_dict:
- inst_executor.run_init_graph(param_dict, init_phase)
-
-
- class _MindSporeFunction:
- """
- Represents a function compiled by mind expression.
-
- _MindSporeFunction will compile the original function for every combination
- of argument types and shapes it is given (as well as their values, optionally).
-
- Args:
- fn (Function): The root function to compile.
- input_signature (Function): User defines signature to verify input.
- obj (Object): If function is a method, obj is the owner of function,
- else, obj is none.
- """
-
- def __init__(self, fn, input_signature=None, obj=None):
- self.fn = fn
- self.save_graphs = context.get_context("save_graphs")
- self.save_graphs_path = context.get_context("save_graphs_path")
- self.input_signature = input_signature
- self.obj = None
- self.identify_obj = None
- if hasattr(obj, fn.__name__):
- self.obj = obj
- elif obj is not None:
- self.identify_obj = obj
- self._executor = Executor_.get_instance()
-
- def build_data_init_graph(self, graph_name):
- """Build GE data graph and init graph for the given graph name."""
- if self.obj is None:
- logger.warning("Make sure parameter should not be used in function")
- para_dict = OrderedDict()
- self._executor.build_data_graph(para_dict, graph_name)
- return
- self._executor.build_data_graph(self.obj.parameters_dict(), graph_name, self.obj.parameters_broadcast_dict())
- init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
- _exec_init_graph(self.obj, init_phase)
-
- def compile(self, arguments_dict, method_name):
- """Returns pipline for the given args."""
- args_list = tuple(arguments_dict.values())
- arg_names = tuple(arguments_dict.keys())
-
- # remove first self parameter when fn is a method
- if self.obj is not None:
- args_list = args_list[1:]
- arg_names = arg_names[1:]
-
- # verify the signature for both function and method
- if self.input_signature is not None:
- signatures = []
- for sig_spec in self.input_signature:
- if not isinstance(sig_spec, MetaTensor):
- raise TypeError("Input_signature is not MetaTensor")
- signatures.append(sig_spec)
- is_valid_input = verify_inputs_signature(signatures, args_list)
- if not is_valid_input:
- raise ValueError("Inputs is incompatible with input signature!")
-
- dic = dict(zip(arg_names, args_list))
- generate_name = self.fn.__module__ + "." + self.fn.__name__
- self.fn.__parse_method__ = method_name
-
- # replace key with obj info and object ext info when fn is a method
- if self.obj is not None:
- self.obj.__parse_method__ = method_name
- generate_name = self.obj.__module__ + "."
- if self.obj.__class__.__name__ != "ClipByNorm":
- generate_name = generate_name + str(self.obj.create_time)
- if self.identify_obj is not None:
- generate_name = generate_name + str(id(self.identify_obj))
-
- key = generate_key(generate_name, dic)
- phase = str(key[1]) + generate_name
- if key not in ms_compile_cache.keys():
- is_compile = False
- if self.obj is None:
- is_compile = self._executor.compile(self.fn, args_list, phase, True)
- else:
- is_compile = self._executor.compile(self.obj, args_list, phase, True)
- if not is_compile:
- raise RuntimeError("Executor compile failed.")
- if context.get_context("enable_ge"):
- self.build_data_init_graph(phase)
- # since function can be redefined, we only cache class method pipeline
- if self.obj is not None or self.identify_obj is not None:
- ms_compile_cache[key] = phase
- return phase
-
- return ms_compile_cache[key]
-
- @_wrap_func
- def __call__(self, *args):
- init_backend()
- converted, arguments_dict, parse_method = _convert_function_arguments(self.fn, *args)
- if not converted:
- raise RuntimeError('Process function parameter is failure')
-
- args_list = tuple(arguments_dict.values())
- if self.obj is not None:
- args_list = args_list[1:]
-
- phase = self.compile(arguments_dict, parse_method)
-
- if context.get_context("precompile_only"):
- return None
- return self._executor(args_list, phase)
-
-
- def ms_function(fn=None, obj=None, input_signature=None):
- """
- Create a callable MindSpore graph from a python function.
-
- This allows the MindSpore runtime to apply optimizations based on graph.
-
- Args:
- fn (Function): The Python function that will be run as a graph. Default: None.
- obj (Object): The Python Object that provides the information for identifying the compiled function.Default:
- None.
- input_signature (MetaTensor): The MetaTensor which describes the input arguments. The MetaTensor specifies
- the shape and dtype of the Tensor and they will be supplied to this function. If input_signature
- is specified, each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept
- `**kwargs`. The shape and dtype of actual inputs should keep the same as input_signature. Otherwise,
- TypeError will be raised. Default: None.
-
- Returns:
- Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
- None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
- equal to the case when `fn` is not None.
-
- Examples:
- >>> def tensor_add(x, y):
- >>> z = F.tensor_add(x, y)
- >>> return z
- >>>
- >>> @ms_function
- >>> def tensor_add_with_dec(x, y):
- >>> z = F.tensor_add(x, y)
- >>> return z
- >>>
- >>> @ms_function(input_signature=(MetaTensor(mindspore.float32, (1, 1, 3, 3)),
- >>> MetaTensor(mindspore.float32, (1, 1, 3, 3))))
- >>> def tensor_add_with_sig(x, y):
- >>> z = F.tensor_add(x, y)
- >>> return z
- >>>
- >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
- >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
- >>>
- >>> tensor_add_graph = ms_function(fn=tensor_add)
- >>> out = tensor_add_graph(x, y)
- >>> out = tensor_add_with_dec(x, y)
- >>> out = tensor_add_with_sig(x, y)
- """
-
- def wrap_mindspore(func):
- @wraps(func)
- def staging_specialize(*args):
- process_obj = obj
- if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
- process_obj = args[0]
- return _MindSporeFunction(func, input_signature, process_obj)(*args)
-
- return staging_specialize
-
- if fn is not None:
- return wrap_mindspore(fn)
- return wrap_mindspore
-
-
- def _generate_pip_args(obj, *args, method="construct"):
- """Generate arguments for pipeline."""
- if hasattr(obj, method):
- fn = getattr(obj, method)
- else:
- raise AttributeError('The process method is not exist')
- converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
- if not converted:
- raise RuntimeError('Process method parameter is failure')
- args_list = tuple(arguments_dict.values())
- args_names = tuple(arguments_dict.keys())
- obj.__parse_method__ = parse_method
- return args_names, args_list
-
-
- class _PynativeExecutor:
- """
- An pynative executor used to compile/manage/run graph.
-
- Returns:
- Graph, return the result of pipeline running.
- """
-
- def __init__(self):
- self._executor = PynativeExecutor_.get_instance()
-
- def new_graph(self, obj, *args, **kwargs):
- self._executor.new_graph(obj, *args, *(kwargs.values()))
-
- def end_graph(self, obj, output, *args, **kwargs):
- self._executor.end_graph(obj, output, *args, *(kwargs.values()))
-
- def grad(self, grad, obj, weights, *args, **kwargs):
- self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))
-
- def clear(self, flag=""):
- self._executor.clear(flag)
-
- def set_grad_flag(self, flag):
- self._executor.set_grad_flag(flag)
-
- def __call__(self, *args, **kwargs):
- args = args + tuple(kwargs.values())
- return self._executor(args, "")
-
-
- class _Executor:
- """
- An executor used to compile/manage/run graph.
-
- Including data_graph, train_graph, eval_graph and predict graph.
-
- Returns:
- Graph, return the result of pipeline running.
- """
-
- def __init__(self):
- # create needed graph by lazy mode
- self.is_init = False
- self._executor = Executor_.get_instance()
- self.compile_cache = {}
- self.phase_prefix = ""
-
- def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
- input_indexs, phase='dataset'):
- """
- Initialization interface for calling data subgraph.
-
- Args:
- queue_name (str): The name of tdt queue on the device.
- dataset_size (int): The size of dataset.
- batch_size (int): The size of batch.
- dataset_types (list): The output types of element in dataset.
- dataset_shapes (list): The output shapes of element in dataset.
- input_indexs (list): The index of data with net.
- phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'.
-
- Returns:
- bool, specifies whether the data subgraph was initialized successfully.
- """
- if not init_exec_dataset(queue_name=queue_name,
- size=dataset_size,
- batch_size=batch_size,
- types=dataset_types,
- shapes=dataset_shapes,
- input_indexs=input_indexs,
- phase=phase):
- raise RuntimeError("Failure to init and dataset subgraph!")
- return True
-
- def _build_data_graph(self, obj, phase):
- self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
-
- def _set_dataset_mode(self, args_list):
- """set dataset mode."""
- # decide whether to sink based on whether the inputs is virtual or args_list is ()
- if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
- (args_list is not None and args_list == ()):
- _set_dataset_mode_config('sink')
- else:
- _set_dataset_mode_config('normal')
-
- def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
- """
- Compiles graph.
-
- Args:
- obj (Function/Cell): The function or cell instance need compile.
- args (tuple): Function or cell input arguments.
- phase (str): The name of compile phase. Default: 'predict'.
- do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
- auto_parallel_mode: When set to True, use auto parallel mode to compile graph.
-
- Return:
- Str, the full phase of the cell.
- Bool, if the graph has been compiled before, return False, else return True.
- """
- obj.check_names()
- _check_full_batch()
- args_names, args_list = _generate_pip_args(obj, *args)
- dic = dict(zip(args_names, args_list))
- key = generate_key(phase, dic)
- self.phase_prefix = str(key[1])
- if 'export' in phase:
- phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
- else:
- phase = self.phase_prefix + phase + '.' + str(obj.create_time)
- enable_debug_runtime = context.get_context("enable_debug_runtime")
- enable_ge = context.get_context("enable_ge")
-
- use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
-
- self._set_dataset_mode(args_list)
-
- if phase in self.compile_cache.keys():
- logger.debug("%r graph has existed.", phase)
- return phase, False
-
- is_sink_mode = args and isinstance(args[0], Tensor) and args[0].virtual_flag
- if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run():
- args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
- _, args_list = _generate_pip_args(obj, *args_full)
-
- result = self._executor.compile(obj, args_list, phase, use_vm)
- self.compile_cache[phase] = phase
- if not result:
- raise RuntimeError("Executor compile failed.")
- graph = self._executor.get_func_graph(phase)
-
- if graph is None:
- logger.error("%r graph compile failed.", phase)
- if not do_convert:
- return phase, True
-
- if auto_parallel_mode:
- obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
- replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
- if not enable_debug_runtime or enable_ge:
- if auto_parallel_mode:
- obj.load_parameter_slice(None)
-
- self._updata_param_node_default_input(phase, replace)
-
- # set parallel inputs in sink mode
- if auto_parallel_mode and is_sink_mode:
- obj.set_parallel_input_with_inputs(*args)
-
- # the following GE init process is not needed when use vm or ms backend
- if enable_ge:
- self._build_data_graph(obj, phase)
-
- if "export" not in phase:
- init_phase = "init_subgraph" + "." + str(obj.create_time)
- _exec_init_graph(obj, init_phase)
- elif not enable_ge and "export" in phase:
- self._build_data_graph(obj, phase)
-
- return phase, True
-
- def _updata_param_node_default_input(self, phase, replace):
- new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
- return self._executor.updata_param_node_default_input(phase, new_param)
-
- def _get_shard_strategy(self, obj):
- real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
- return self._executor.get_strategy(real_phase)
-
- def _get_allreduce_fusion(self, obj):
- real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
- return self._executor.get_allreduce_fusion(real_phase)
-
- def has_compiled(self, phase='predict'):
- """
- Specify whether have been compiled.
-
- Args:
- phase (str): The phase name. Default: 'predict'.
-
- Returns:
- bool, specifies whether the specific graph has been compiled.
- """
- return self._executor.has_compiled(phase)
-
- def __call__(self, obj, *args, phase='predict'):
- if context.get_context("precompile_only") or _is_role_pserver():
- return None
- return self.run(obj, *args, phase=phase)
-
- @_wrap_func
- def _exec_pip(self, obj, *args, phase=''):
- """Execute the generated pipeline."""
- fn = obj.construct
- converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
- if not converted:
- raise RuntimeError('Process method parameter is failure')
- args_list = tuple(arguments_dict.values())
- obj.__parse_method__ = parse_method
- return self._executor(args_list, phase)
-
- def run(self, obj, *args, phase='predict'):
- """
- Run the specific graph.
-
- Args:
- phase (str): The phase name. Default: 'predict'.
-
- Returns:
- Tensor/Tuple, return execute result.
- """
- if phase == 'save':
- return self._executor((), phase + '.' + str(obj.create_time))
-
- phase_real = self.phase_prefix + phase + '.' + str(obj.create_time)
- if self.has_compiled(phase_real):
- return self._exec_pip(obj, *args, phase=phase_real)
- raise KeyError('{} graph is not exist.'.format(phase_real))
-
- def del_net_res(self, net_id):
- self._executor.del_net_res(net_id)
-
- def _get_func_graph_proto(self, exec_id, ir_type="onnx_ir", use_prefix=False):
- """Get graph proto from pipeline."""
- if use_prefix:
- exec_id = self.phase_prefix + exec_id
- if self._executor.has_compiled(exec_id) is False:
- return None
- return self._executor.get_func_graph_proto(exec_id, ir_type)
-
- def export(self, file_name, graph_id):
- """
- Export graph.
-
- Args:
- file_name (str): File name of model to export
- graph_id (str): id of graph to be exported
- """
- from .._c_expression import export_graph
- export_graph(file_name, 'AIR', graph_id)
-
- def fetch_info_for_quant_export(self, exec_id):
- """Get graph proto from pipeline."""
- if self._executor.has_compiled(exec_id) is False:
- return None
- return self._executor.fetch_info_for_quant_export(exec_id)
-
-
- _executor = _Executor()
- _pynative_exec = _PynativeExecutor()
-
- __all__ = ['ms_function']
|