Browse Source

Delete unused parameters in pipeline parallel mode

tags/v1.2.0-rc1
yangzhenzhang 5 years ago
parent
commit
cbca482e59
9 changed files with 75 additions and 4 deletions
  1. +16
    -0
      mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
  2. +1
    -0
      mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h
  3. +2
    -0
      mindspore/ccsrc/pipeline/jit/init.cc
  4. +6
    -0
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  5. +1
    -0
      mindspore/ccsrc/pipeline/jit/pipeline.h
  6. +4
    -1
      mindspore/common/api.py
  7. +35
    -2
      mindspore/nn/cell.py
  8. +4
    -0
      mindspore/parallel/_utils.py
  9. +6
    -1
      tests/ut/python/parallel/test_pipeline_split.py

+ 16
- 0
mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc View File

@@ -79,5 +79,21 @@ py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
}
return dict;
}

// In pipeline parallel mode, many parameters are not used on the local stage
// and need to be deleted; return the names of the parameters that the graph
// actually keeps so the Python side can prune the rest.
py::list GetParallelParameterNameList(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  py::list parallel_parameter_name_list;
  // Bind by const reference: avoids copying the vector of shared pointers
  // if parameters() returns by value or by reference.
  const auto &graph_params = graph->parameters();

  // const auto & avoids a shared_ptr copy (atomic refcount) per iteration.
  for (const auto &param : graph_params) {
    auto param_ptr = std::static_pointer_cast<Parameter>(param);
    MS_EXCEPTION_IF_NULL(param_ptr);
    parallel_parameter_name_list.append(param_ptr->name());
  }
  return parallel_parameter_name_list;
}
} // namespace parallel
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h View File

@@ -26,6 +26,7 @@ namespace mindspore {
namespace parallel {
py::dict GetParameterLayout(const FuncGraphPtr &graph);
py::dict GetAllreduceFusion(const FuncGraphPtr &graph);
py::list GetParallelParameterNameList(const FuncGraphPtr &graph);
} // namespace parallel
} // namespace mindspore



+ 2
- 0
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -80,6 +80,8 @@ PYBIND11_MODULE(_c_expression, m) {
py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
.def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
"Get Parameter Tensor Layout Dictionary.")
.def("get_parallel_parameter_name_list", &ExecutorPy::GetParallelParameterNameList,
py::arg("phase") = py::str("train"), "Get Parallel Parameter Name List.")
.def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),
"Get CNode Strategy Dictionary.")
.def("get_num_parallel_ops", &ExecutorPy::GetNumOpsInfo, py::arg("phase") = py::str("train"),


+ 6
- 0
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -286,6 +286,12 @@ py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) {
return stra_dict_[phase];
}

// Fetch the parallel parameter names recorded on the step-parallel graph
// that was compiled for the given phase.
py::list ExecutorPy::GetParallelParameterNameList(const std::string &phase) {
  const std::string graph_key = phase + kStepParallelGraph;
  auto func_graph = GetFuncGraph(graph_key);
  return mindspore::parallel::GetParallelParameterNameList(func_graph);
}

void ExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
MS_LOG(DEBUG) << "SetCNodeStrategy!";
stra_dict_[phase_][py::str(name)] = strategy;


+ 1
- 0
mindspore/ccsrc/pipeline/jit/pipeline.h View File

@@ -93,6 +93,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
void PyExePath(const py::object &phase);
py::dict GetParameterLayout(const std::string &phase);
py::dict GetCNodeStrategy(const std::string &phase);
py::list GetParallelParameterNameList(const std::string &phase);
void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy);
size_t GetNumOpsInfo(const std::string &phase);
void SetNumOpsInfo(size_t);


+ 4
- 1
mindspore/common/api.py View File

@@ -27,7 +27,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynativ
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
_get_parameter_broadcast
_get_parameter_broadcast, _get_pipeline_stages

# store ms_function class compiled pipeline cache
ms_compile_cache = {}
@@ -501,6 +501,9 @@ class _Executor:

if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
if _get_pipeline_stages() > 1:
obj.parallel_parameter_name_list = self._executor.get_parallel_parameter_name_list(phase)
obj.remove_redundant_parameters()
replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
if not enable_debug_runtime or enable_ge:
if auto_parallel_mode:


+ 35
- 2
mindspore/nn/cell.py View File

@@ -89,6 +89,7 @@ class Cell(Cell_):
self._scope = None
self._phase = 'train'
self._parameter_layout_dict = {}
self._parallel_parameter_name_list = ()
self._create_time = int(time.time() * 1e9)
self.phase_prefix = ""
self.parameter_broadcast_done = False
@@ -213,6 +214,16 @@ class Cell(Cell_):
raise TypeError("'parameter_layout_dict' must be dict type.")
self._parameter_layout_dict = value

@property
def parallel_parameter_name_list(self):
    """Names of the parameters kept on the local stage in pipeline parallel mode."""
    return self._parallel_parameter_name_list

@parallel_parameter_name_list.setter
def parallel_parameter_name_list(self, value):
    """Set the parallel parameter name list.

    Accepts a list or a tuple, matching the attribute's default value
    of ``()`` (only membership tests are performed on it).

    Raises:
        TypeError: If `value` is neither a list nor a tuple.
    """
    if not isinstance(value, (list, tuple)):
        raise TypeError("'parallel_parameter_name_list' must be list or tuple type.")
    self._parallel_parameter_name_list = value

def get_func_graph_proto(self):
"""Return graph binary proto."""
return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True)
@@ -656,6 +667,28 @@ class Cell(Cell_):
"""
return None

def remove_redundant_parameters(self):
    """Remove the redundant parameters.

    In pipeline parallel mode only the parameters used by the local stage
    are kept; every parameter whose name is absent from
    `parallel_parameter_name_list` is dropped from each sub-cell's
    `_params` dict and from any `ParameterTuple` attribute of the cell.
    """
    for _, cell in self.cells_and_names():
        # Iterate over a snapshot because we pop from _params while looping.
        for param_name, param in list(cell._params.items()):
            if param.name not in self.parallel_parameter_name_list:
                cell._params.pop(param_name)
                logger.info("remove the redundant parameter: %s", param.name)
        # Rebuild every ParameterTuple attribute without the redundant entries.
        cell_dict = cell.__dict__
        for key in cell_dict:
            if isinstance(cell_dict[key], ParameterTuple):
                kept_params = []
                for param in cell_dict[key]:
                    if param.name not in self.parallel_parameter_name_list:
                        logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
                        continue
                    kept_params.append(param)
                cell_dict[key] = ParameterTuple(kept_params)

def init_parameters_data(self, auto_parallel_mode=False):
"""
Initialize all parameters and replace the original saved parameters in cell.
@@ -750,7 +783,7 @@ class Cell(Cell_):
"""
Returns all trainable parameters.

Returns a list of all trainable parmeters.
Returns a list of all trainable parameters.

Args:
recurse (bool): Whether contains the trainable parameters of subcells. Default: True.
@@ -1031,7 +1064,7 @@ class Cell(Cell_):
Note:
fn must be defined as the following code. `cell_name` is the name of registered cell.
`grad_input` is gradient passed to the cell. `grad_output` is the gradient computed and passed to the
next cell or primitve, which may be modified and returned.
next cell or primitive, which may be modified and returned.
hook_fn(cell_name, grad_input, grad_output) -> Tensor or None.

Args:


+ 4
- 0
mindspore/parallel/_utils.py View File

@@ -35,6 +35,10 @@ def _get_full_batch():
"""Get whether to use full_batch."""
return auto_parallel_context().get_full_batch()

def _get_pipeline_stages():
    """Get the number of pipeline stages from the auto-parallel context."""
    return auto_parallel_context().get_pipeline_stages()

def _check_full_batch():
"""
full_batch could only be used under semi_auto_parallel or auto_parallel, check it.


+ 6
- 1
tests/ut/python/parallel/test_pipeline_split.py View File

@@ -120,7 +120,9 @@ def test_pipeline_split_stage0():
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)

for _, param in model._train_network.parameters_and_names():
assert param.name != "cell.block.1.param"
assert param.name != "cell.block.1.param1"

def test_pipeline_split_stage1():
context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
@@ -135,6 +137,9 @@ def test_pipeline_split_stage1():
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
for _, param in model._train_network.parameters_and_names():
assert param.name != "cell.block.0.param"
assert param.name != "cell.block.0.param1"


def test_pipeline_split_shared_parameter_stage0():


Loading…
Cancel
Save