Browse Source

!11338 Support parameter broadcast in data parallel mode under PyNative

From: @zuochuanyong
Reviewed-by: @zhoufeng54,@chujinjin
Signed-off-by: @chujinjin
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
b988780fd7
2 changed files with 55 additions and 37 deletions
  1. +48
    -37
      mindspore/common/api.py
  2. +7
    -0
      mindspore/nn/cell.py

+ 48
- 37
mindspore/common/api.py View File

@@ -298,6 +298,49 @@ def _generate_pip_args(obj, *args, method="construct"):
return args_names, args_list


def _get_auto_split_param_names(parameter_layout_dict):
auto_split_params = {}
for key, value in parameter_layout_dict.items():
for dim in value[1]:
if dim != -1:
auto_split_params[key] = value
break
auto_split_param_names = (param_name for param_name in auto_split_params)
return auto_split_param_names


def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
    """Run a ``_BroadCastCell`` over the given parameters and store the results back.

    Args:
        broadcast_params_dict (dict): Maps parameter names to parameter objects that
            provide ``asnumpy()`` and ``set_data()``. A falsy value is treated as empty.
        broadcast_phase (str): Value assigned to the ``phase`` attribute of the cell
            before it is invoked.
    """
    # Imported here rather than at module level, as in the original code.
    from mindspore.nn.wrap.cell_wrapper import _BroadCastCell

    params_by_name = broadcast_params_dict if broadcast_params_dict else {}

    # Each parameter is fed to the cell as a fresh Tensor built from its numpy value.
    input_tensors = [Tensor(target.asnumpy()) for target in params_by_name.values()]

    broadcast_cell = _BroadCastCell(input_tensors)
    broadcast_cell.phase = broadcast_phase
    outputs = broadcast_cell()

    # Outputs are matched to parameters by position, in dict iteration order.
    for target, new_value in zip(params_by_name.values(), outputs):
        target.set_data(new_value)


def _parameter_broadcast(obj, auto_parallel_mode):
    """Broadcast the parameters of ``obj``, leaving out auto-split ones in auto parallel mode.

    Args:
        obj: Object providing ``parameter_layout_dict`` and ``parameters_broadcast_dict()``.
        auto_parallel_mode (bool): When truthy, parameters reported by
            ``_get_auto_split_param_names`` are excluded from the broadcast.
    """
    # Materialise the names into a set. The helper's result may be a one-shot iterator,
    # which is always truthy and is consumed by each `not in` test; filtering against it
    # in a loop could let split parameters slip through. A set also makes the emptiness
    # check below meaningful and each lookup O(1).
    auto_split_param_names = set()
    if auto_parallel_mode:
        auto_split_param_names = set(_get_auto_split_param_names(obj.parameter_layout_dict))

    broadcast_params_dict = obj.parameters_broadcast_dict()
    if auto_split_param_names and broadcast_params_dict:
        # Keep the original ordering while dropping the auto-split parameters.
        broadcast_params_dict = OrderedDict(
            (param_name, param)
            for param_name, param in broadcast_params_dict.items()
            if param_name not in auto_split_param_names
        )

    broadcast_phase = "_broadcast_subgraph"
    _build_broadcast_graph(broadcast_params_dict, broadcast_phase)


class _PynativeExecutor:
"""
A pynative executor used to compile/manage/run graph.
@@ -339,6 +382,10 @@ class _PynativeExecutor:
def leave_construct(self, cell):
    """Forward the leave-construct call for ``cell`` to the wrapped executor."""
    backend = self._executor
    backend.leave_construct(cell)

def parameter_broadcast(self, obj, phase, auto_parallel_mode):
    """Broadcast the parameters of ``obj`` unless the phase or the settings rule it out.

    Nothing happens when ``BROADCAST_PHASE`` occurs in ``phase`` or when
    ``_get_parameter_broadcast()`` returns a falsy value.
    """
    if BROADCAST_PHASE in phase:
        return
    if not _get_parameter_broadcast():
        return
    _parameter_broadcast(obj, auto_parallel_mode)

def __call__(self, obj, *args, **kwargs):
args = args + tuple(kwargs.values())
return self._executor(obj, args, "")
@@ -391,31 +438,6 @@ class _Executor:
def _build_data_graph(self, obj, phase):
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

def _get_auto_split_param_names(self, parameter_layout_dict):
auto_split_params = {}
for key, value in parameter_layout_dict.items():
for dim in value[1]:
if dim != -1:
auto_split_params[key] = value
break
auto_split_param_names = (param_name for param_name in auto_split_params)
return auto_split_param_names

def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase):
    """Run a ``_BroadCastCell`` over the given parameters and store the results back.

    Args:
        broadcast_params_dict (dict): Maps parameter names to parameter objects that
            provide ``asnumpy()`` and ``set_data()``. A falsy value is treated as empty.
        broadcast_phase (str): Value assigned to the ``phase`` attribute of the cell
            before it is invoked.
    """
    from mindspore.nn.wrap.cell_wrapper import _BroadCastCell

    targets = broadcast_params_dict or {}

    # Inputs are fresh Tensors built from each parameter's numpy value.
    cell_inputs = [Tensor(entry.asnumpy()) for entry in targets.values()]

    cell = _BroadCastCell(cell_inputs)
    cell.phase = broadcast_phase
    cell_outputs = cell()

    # Pair each output with its parameter by position, in dict iteration order.
    for entry, updated in zip(targets.values(), cell_outputs):
        entry.set_data(updated)

def _set_dataset_mode(self, args_list):
"""set dataset mode."""
# decide whether to sink based on whether the inputs are virtual or args_list is ()
@@ -500,18 +522,7 @@ class _Executor:
elif not enable_ge and "export" in phase:
self._build_data_graph(obj, phase)
elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
auto_split_param_names = []
if auto_parallel_mode:
auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)

broadcast_params_dict = obj.parameters_broadcast_dict()
if auto_split_param_names and broadcast_params_dict:
broadcast_params_dict = OrderedDict()
for param_name, param in obj.parameters_broadcast_dict().items():
if param_name not in auto_split_param_names:
broadcast_params_dict[param_name] = param
broadcast_phase = "_broadcast_subgraph"
self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)
_parameter_broadcast(obj, auto_parallel_mode)

return phase, True



+ 7
- 0
mindspore/nn/cell.py View File

@@ -23,6 +23,7 @@ import numpy

from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.context import ParallelMode
from .. import context
from .._c_expression import init_pipeline, Cell_
from .._checkparam import Validator
@@ -90,6 +91,7 @@ class Cell(Cell_):
self._parameter_layout_dict = {}
self._create_time = int(time.time() * 1e9)
self.phase_prefix = ""
self.parameter_broadcast_done = False
init_pipeline()

# call gc to release GE session resources used by non-used cell objects
@@ -300,6 +302,11 @@ class Cell(Cell_):
out = self.compile_and_run(*inputs)
return out

if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
if not self.parameter_broadcast_done:
_pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
self.parameter_broadcast_done = True

for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("cell inputs should not be numpy array.")


Loading…
Cancel
Save