|
|
|
@@ -298,6 +298,49 @@ def _generate_pip_args(obj, *args, method="construct"): |
|
|
|
return args_names, args_list |
|
|
|
|
|
|
|
|
|
|
|
def _get_auto_split_param_names(parameter_layout_dict): |
|
|
|
auto_split_params = {} |
|
|
|
for key, value in parameter_layout_dict.items(): |
|
|
|
for dim in value[1]: |
|
|
|
if dim != -1: |
|
|
|
auto_split_params[key] = value |
|
|
|
break |
|
|
|
auto_split_param_names = (param_name for param_name in auto_split_params) |
|
|
|
return auto_split_param_names |
|
|
|
|
|
|
|
|
|
|
|
def _build_broadcast_graph(broadcast_params_dict, broadcast_phase): |
|
|
|
"""Build broadcast graph.""" |
|
|
|
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell |
|
|
|
|
|
|
|
if not broadcast_params_dict: |
|
|
|
broadcast_params_dict = {} |
|
|
|
broadcast_params = [] |
|
|
|
for param in broadcast_params_dict.values(): |
|
|
|
broadcast_params.append(Tensor(param.asnumpy())) |
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params) |
|
|
|
_broadcast_net.phase = broadcast_phase |
|
|
|
broadcasted_params = _broadcast_net() |
|
|
|
for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params): |
|
|
|
broadcast_params_dict[param_name].set_data(param) |
|
|
|
|
|
|
|
|
|
|
|
def _parameter_broadcast(obj, auto_parallel_mode): |
|
|
|
"""Parameter broadcast.""" |
|
|
|
auto_split_param_names = [] |
|
|
|
if auto_parallel_mode: |
|
|
|
auto_split_param_names = _get_auto_split_param_names(obj.parameter_layout_dict) |
|
|
|
|
|
|
|
broadcast_params_dict = obj.parameters_broadcast_dict() |
|
|
|
if auto_split_param_names and broadcast_params_dict: |
|
|
|
broadcast_params_dict = OrderedDict() |
|
|
|
for param_name, param in obj.parameters_broadcast_dict().items(): |
|
|
|
if param_name not in auto_split_param_names: |
|
|
|
broadcast_params_dict[param_name] = param |
|
|
|
broadcast_phase = "_broadcast_subgraph" |
|
|
|
_build_broadcast_graph(broadcast_params_dict, broadcast_phase) |
|
|
|
|
|
|
|
|
|
|
|
class _PynativeExecutor: |
|
|
|
""" |
|
|
|
An pynative executor used to compile/manage/run graph. |
|
|
|
@@ -339,6 +382,10 @@ class _PynativeExecutor: |
|
|
|
def leave_construct(self, cell): |
|
|
|
self._executor.leave_construct(cell) |
|
|
|
|
|
|
|
def parameter_broadcast(self, obj, phase, auto_parallel_mode): |
|
|
|
if BROADCAST_PHASE not in phase and _get_parameter_broadcast(): |
|
|
|
_parameter_broadcast(obj, auto_parallel_mode) |
|
|
|
|
|
|
|
def __call__(self, obj, *args, **kwargs): |
|
|
|
args = args + tuple(kwargs.values()) |
|
|
|
return self._executor(obj, args, "") |
|
|
|
@@ -391,31 +438,6 @@ class _Executor: |
|
|
|
def _build_data_graph(self, obj, phase): |
|
|
|
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict()) |
|
|
|
|
|
|
|
def _get_auto_split_param_names(self, parameter_layout_dict): |
|
|
|
auto_split_params = {} |
|
|
|
for key, value in parameter_layout_dict.items(): |
|
|
|
for dim in value[1]: |
|
|
|
if dim != -1: |
|
|
|
auto_split_params[key] = value |
|
|
|
break |
|
|
|
auto_split_param_names = (param_name for param_name in auto_split_params) |
|
|
|
return auto_split_param_names |
|
|
|
|
|
|
|
def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase): |
|
|
|
"""Build broadcast graph.""" |
|
|
|
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell |
|
|
|
|
|
|
|
if not broadcast_params_dict: |
|
|
|
broadcast_params_dict = {} |
|
|
|
broadcast_params = [] |
|
|
|
for param in broadcast_params_dict.values(): |
|
|
|
broadcast_params.append(Tensor(param.asnumpy())) |
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params) |
|
|
|
_broadcast_net.phase = broadcast_phase |
|
|
|
broadcasted_params = _broadcast_net() |
|
|
|
for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params): |
|
|
|
broadcast_params_dict[param_name].set_data(param) |
|
|
|
|
|
|
|
def _set_dataset_mode(self, args_list): |
|
|
|
"""set dataset mode.""" |
|
|
|
# decide whether to sink based on whether the inputs is virtual or args_list is () |
|
|
|
@@ -500,18 +522,7 @@ class _Executor: |
|
|
|
elif not enable_ge and "export" in phase: |
|
|
|
self._build_data_graph(obj, phase) |
|
|
|
elif BROADCAST_PHASE not in phase and _get_parameter_broadcast(): |
|
|
|
auto_split_param_names = [] |
|
|
|
if auto_parallel_mode: |
|
|
|
auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict) |
|
|
|
|
|
|
|
broadcast_params_dict = obj.parameters_broadcast_dict() |
|
|
|
if auto_split_param_names and broadcast_params_dict: |
|
|
|
broadcast_params_dict = OrderedDict() |
|
|
|
for param_name, param in obj.parameters_broadcast_dict().items(): |
|
|
|
if param_name not in auto_split_param_names: |
|
|
|
broadcast_params_dict[param_name] = param |
|
|
|
broadcast_phase = "_broadcast_subgraph" |
|
|
|
self._build_broadcast_graph(broadcast_params_dict, broadcast_phase) |
|
|
|
_parameter_broadcast(obj, auto_parallel_mode) |
|
|
|
|
|
|
|
return phase, True |
|
|
|
|
|
|
|
|