|
|
|
@@ -225,11 +225,10 @@ def ms_function(fn=None, obj=None, input_signature=None): |
|
|
|
fn (Function): The Python function that will be run as a graph. Default: None. |
|
|
|
obj (Object): The Python Object that provides the information for identifying the compiled function.Default: |
|
|
|
None. |
|
|
|
input_signature (MetaTensor): The MetaTensor which describes the input arguments. The MetaTensor specifies |
|
|
|
the shape and dtype of the Tensor and they will be supplied to this function. If input_signature |
|
|
|
is specified, each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept |
|
|
|
`**kwargs`. The shape and dtype of actual inputs should keep the same as input_signature. Otherwise, |
|
|
|
TypeError will be raised. Default: None. |
|
|
|
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor |
|
|
|
will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`. |
|
|
|
And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should |
|
|
|
keep the same as input_signature. Otherwise, TypeError will be raised. Default: None. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is |
|
|
|
@@ -259,8 +258,8 @@ def ms_function(fn=None, obj=None, input_signature=None): |
|
|
|
>>> out = tensor_add_with_dec(x, y) |
|
|
|
... |
|
|
|
>>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter |
|
|
|
>>> @ms_function(input_signature=(MetaTensor(mindspore.float32, (1, 1, 3, 3)), |
|
|
|
... MetaTensor(mindspore.float32, (1, 1, 3, 3)))) |
|
|
|
>>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)), |
|
|
|
... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)))) |
|
|
|
... def tensor_add_with_sig(x, y): |
|
|
|
... z = x + y |
|
|
|
... return z |
|
|
|
@@ -299,13 +298,12 @@ def _generate_pip_args(obj, *args, method="construct"): |
|
|
|
|
|
|
|
|
|
|
|
def _get_auto_split_param_names(parameter_layout_dict): |
|
|
|
auto_split_params = {} |
|
|
|
auto_split_param_names = [] |
|
|
|
for key, value in parameter_layout_dict.items(): |
|
|
|
for dim in value[1]: |
|
|
|
if dim != -1: |
|
|
|
auto_split_params[key] = value |
|
|
|
auto_split_param_names.append(key) |
|
|
|
break |
|
|
|
auto_split_param_names = (param_name for param_name in auto_split_params) |
|
|
|
return auto_split_param_names |
|
|
|
|
|
|
|
|
|
|
|
@@ -499,21 +497,7 @@ class _Executor: |
|
|
|
if not do_convert: |
|
|
|
return phase, True |
|
|
|
|
|
|
|
if auto_parallel_mode: |
|
|
|
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) |
|
|
|
if _get_pipeline_stages() > 1: |
|
|
|
obj.parallel_parameter_name_list = self._executor.get_parallel_parameter_name_list(phase) |
|
|
|
obj.remove_redundant_parameters() |
|
|
|
replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) |
|
|
|
if not enable_debug_runtime or enable_ge: |
|
|
|
if auto_parallel_mode: |
|
|
|
obj.load_parameter_slice(None) |
|
|
|
|
|
|
|
self._updata_param_node_default_input(phase, replace) |
|
|
|
|
|
|
|
# set parallel inputs in sink mode |
|
|
|
if auto_parallel_mode and is_sink_mode: |
|
|
|
obj.set_parallel_input_with_inputs(*args) |
|
|
|
self._auto_parallel_process(obj, phase, is_sink_mode, auto_parallel_mode, *args) |
|
|
|
|
|
|
|
# the following GE init process is not needed when use vm or ms backend |
|
|
|
if enable_ge: |
|
|
|
@@ -529,6 +513,27 @@ class _Executor: |
|
|
|
|
|
|
|
return phase, True |
|
|
|
|
|
|
|
def _auto_parallel_process(self, obj, phase, is_sink_mode, auto_parallel_mode, *args): |
|
|
|
"""compile graph in auto parallel mode.""" |
|
|
|
if not auto_parallel_mode: |
|
|
|
replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) |
|
|
|
self._updata_param_node_default_input(phase, replace) |
|
|
|
return |
|
|
|
|
|
|
|
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) |
|
|
|
if _get_pipeline_stages() > 1: |
|
|
|
obj.parallel_parameter_name_list = self._executor.get_parallel_parameter_name_list(phase) |
|
|
|
obj.remove_redundant_parameters() |
|
|
|
replace = obj.init_parameters_data(auto_parallel_mode=True) |
|
|
|
if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"): |
|
|
|
obj.load_parameter_slice(None) |
|
|
|
|
|
|
|
self._updata_param_node_default_input(phase, replace) |
|
|
|
|
|
|
|
# set parallel inputs in sink mode |
|
|
|
if is_sink_mode: |
|
|
|
obj.set_parallel_input_with_inputs(*args) |
|
|
|
|
|
|
|
def _updata_param_node_default_input(self, phase, replace): |
|
|
|
new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])} |
|
|
|
return self._executor.updata_param_node_default_input(phase, new_param) |
|
|
|
|