Browse Source

!1510 seperate auto_parallel and stand_alone when init initializer data

Merge pull request !1510 from yihuaijie/dev
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
d5634eb4d5
3 changed files with 36 additions and 20 deletions
  1. +11
    -8
      mindspore/common/api.py
  2. +14
    -6
      mindspore/common/parameter.py
  3. +11
    -6
      mindspore/nn/cell.py

+ 11
- 8
mindspore/common/api.py View File

@@ -327,16 +327,19 @@ class _Executor:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))

def _params_init_data(self, obj, params):
def _params_init_data(self, obj, params, auto_parallel_mode=False):
"""Init parameters' data."""
if params is not None:
for key, param in params.items():
if key not in obj.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", key)
if not auto_parallel_mode:
param.init_data()
elif key not in obj.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", key)
param.init_data(set_sliced=True)
else:
layout = obj.parameter_layout_dict[key]
param.init_data(layout)
obj.init_parameters_data()
param.init_data(layout, set_sliced=True)
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)

def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
"""
@@ -383,11 +386,11 @@ class _Executor:
if not do_convert:
return phase, True

if auto_parallel_mode and "train" in phase:
if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
self._params_init_data(obj, params)
self._params_init_data(obj, params, auto_parallel_mode)
if not enable_debug_runtime or enable_ge:
if auto_parallel_mode and "train" in phase:
if auto_parallel_mode:
obj.load_parameter_slice(params)

# set parallel inputs in sink mode


+ 14
- 6
mindspore/common/parameter.py View File

@@ -99,6 +99,10 @@ class Parameter:
"""Get slice status of the parameter."""
return self._sliced

@sliced.setter
def sliced(self, sliced_):
self._sliced = sliced_

@property
def is_init(self):
"""Get init status of the parameter."""
@@ -211,15 +215,18 @@ class Parameter:
self.default_input = data


def init_data(self, layout=None):
def init_data(self, layout=None, set_sliced=False):
"""
Init data of the parameter.

Args:
layout (list[list[int]]): parameter slice layout [dev_mat, tensor_map, slice_shape].
dev_mat (list[int]): device matrix.
tensor_map (list[int]): tensor map.
slice_shape (list[int]): shape of slice.
layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape].

- dev_mat (list[int]): Device matrix.
- tensor_map (list[int]): Tensor map.
- slice_shape (list[int]): Shape of slice.
set_sliced (bool): True if should set parameter sliced after init the data of initializer.
Default: False.
"""
if not isinstance(self.default_input, MetaTensor):
return
@@ -235,7 +242,8 @@ class Parameter:

self.default_input = self.init_mode.to_tensor()
self.init_mode = None
self._sliced = True
if set_sliced:
self.sliced = True


class ParameterTuple(tuple):


+ 11
- 6
mindspore/nn/cell.py View File

@@ -264,11 +264,12 @@ class Cell:
logger.info("layout dict does not contain the key %s", key)
continue
if self.parameters_dict()[key].sliced:
logger.info("Param %s is from initializer, already sliced.", key)
logger.info("Param %s is already sliced.", key)
continue
layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout)
self.parameters_dict()[key].set_parameter_data(new_tensor)
self.parameters_dict()[key].sliced = True
elif isinstance(params, OrderedDict):
for key in params:
tensor = params[key].data
@@ -276,11 +277,12 @@ class Cell:
logger.info("layout dict does not contain the key %s", key)
continue
if params[key].sliced:
logger.info("Param %s is from initializer, already sliced.", key)
logger.info("Param %s is already sliced.", key)
continue
layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout)
params[key].set_parameter_data(new_tensor)
params[key].sliced = True
else:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))
@@ -435,14 +437,17 @@ class Cell:
"""
raise NotImplementedError

def init_parameters_data(self, recurse=True):
def init_parameters_data(self, recurse=True, auto_parallel_mode=False):
"""Init parameters' data."""
for param in self.get_parameters(expand=recurse):
if param.name not in self.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", param.name)
if not auto_parallel_mode:
param.init_data()
elif param.name not in self.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", param.name)
param.init_data(set_sliced=True)
else:
layout = self.parameter_layout_dict[param.name]
param.init_data(layout)
param.init_data(layout, set_sliced=True)

def parameters_dict(self, recurse=True):
"""


Loading…
Cancel
Save