From 8d88f6d93595a00897d347b8d075b08441041c51 Mon Sep 17 00:00:00 2001 From: jinyaohui Date: Mon, 8 Feb 2021 17:15:34 +0800 Subject: [PATCH] modify broadcast --- mindspore/common/api.py | 55 ++++++++++++++++++--------------- mindspore/common/initializer.py | 2 ++ 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 5dfe8fe092..1312e8ee4a 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -225,11 +225,10 @@ def ms_function(fn=None, obj=None, input_signature=None): fn (Function): The Python function that will be run as a graph. Default: None. obj (Object): The Python Object that provides the information for identifying the compiled function.Default: None. - input_signature (MetaTensor): The MetaTensor which describes the input arguments. The MetaTensor specifies - the shape and dtype of the Tensor and they will be supplied to this function. If input_signature - is specified, each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept - `**kwargs`. The shape and dtype of actual inputs should keep the same as input_signature. Otherwise, - TypeError will be raised. Default: None. + input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor + will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`. + And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should + keep the same as input_signature. Otherwise, TypeError will be raised. Default: None. Returns: Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is @@ -259,8 +258,8 @@ def ms_function(fn=None, obj=None, input_signature=None): >>> out = tensor_add_with_dec(x, y) ... >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter - >>> @ms_function(input_signature=(MetaTensor(mindspore.float32, (1, 1, 3, 3)), - ... MetaTensor(mindspore.float32, (1, 1, 3, 3)))) + >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)), + ... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)))) ... def tensor_add_with_sig(x, y): ... z = x + y ... return z @@ -299,13 +298,12 @@ def _generate_pip_args(obj, *args, method="construct"): def _get_auto_split_param_names(parameter_layout_dict): - auto_split_params = {} + auto_split_param_names = [] for key, value in parameter_layout_dict.items(): for dim in value[1]: if dim != -1: - auto_split_params[key] = value + auto_split_param_names.append(key) break - auto_split_param_names = (param_name for param_name in auto_split_params) return auto_split_param_names @@ -499,21 +497,7 @@ class _Executor: if not do_convert: return phase, True - if auto_parallel_mode: - obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) - if _get_pipeline_stages() > 1: - obj.parallel_parameter_name_list = self._executor.get_parallel_parameter_name_list(phase) - obj.remove_redundant_parameters() - replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) - if not enable_debug_runtime or enable_ge: - if auto_parallel_mode: - obj.load_parameter_slice(None) - - self._updata_param_node_default_input(phase, replace) - - # set parallel inputs in sink mode - if auto_parallel_mode and is_sink_mode: - obj.set_parallel_input_with_inputs(*args) + self._auto_parallel_process(obj, phase, is_sink_mode, auto_parallel_mode, *args) # the following GE init process is not needed when use vm or ms backend if enable_ge: @@ -529,6 +513,27 @@ class _Executor: return phase, True + def _auto_parallel_process(self, obj, phase, is_sink_mode, auto_parallel_mode, *args): + """compile graph in auto parallel mode.""" + if not auto_parallel_mode: + replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) + self._updata_param_node_default_input(phase, replace) + return + + obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) + if _get_pipeline_stages() > 1: + obj.parallel_parameter_name_list = self._executor.get_parallel_parameter_name_list(phase) + obj.remove_redundant_parameters() + replace = obj.init_parameters_data(auto_parallel_mode=True) + if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"): + obj.load_parameter_slice(None) + + self._updata_param_node_default_input(phase, replace) + + # set parallel inputs in sink mode + if is_sink_mode: + obj.set_parallel_input_with_inputs(*args) + def _updata_param_node_default_input(self, phase, replace): new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])} return self._executor.updata_param_node_default_input(phase, new_param) diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index d3b186e02a..42c2a18606 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -420,6 +420,8 @@ def initializer(init, shape=None, dtype=mstype.float32): Examples: + >>> import mindspore + >>> from mindspore.common.initializer import initializer, One >>> tensor = initializer('ones', [1, 2, 3], mindspore.float32) >>> tensor = initializer(One(), [1, 2, 3], mindspore.float32) >>> tensor = initializer(0, [1, 2, 3], mindspore.float32)