From d36af3d8e56c4f5c6a3aef452390f5af8630a972 Mon Sep 17 00:00:00 2001
From: Yi Huaijie
Date: Wed, 27 May 2020 12:58:47 +0800
Subject: [PATCH] separate auto_parallel and stand_alone when initializing
 initializer data

---
 mindspore/common/api.py       | 19 +++++++++++--------
 mindspore/common/parameter.py | 20 ++++++++++++++------
 mindspore/nn/cell.py          | 17 +++++++++++------
 3 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 5c1fba328e..529eb8060c 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -327,16 +327,19 @@ class _Executor:
             raise TypeError('Parameters need OrderedDict type, but got {}'.
                             format(type(params)))
 
-    def _params_init_data(self, obj, params):
+    def _params_init_data(self, obj, params, auto_parallel_mode=False):
+        """Init parameters' data."""
         if params is not None:
             for key, param in params.items():
-                if key not in obj.parameter_layout_dict:
-                    logger.info("Layout dict does not contain the key %s.", key)
+                if not auto_parallel_mode:
                     param.init_data()
+                elif key not in obj.parameter_layout_dict:
+                    logger.info("Layout dict does not contain the key %s.", key)
+                    param.init_data(set_sliced=True)
                 else:
                     layout = obj.parameter_layout_dict[key]
-                    param.init_data(layout)
-        obj.init_parameters_data()
+                    param.init_data(layout, set_sliced=True)
+        obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
 
     def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
         """
@@ -383,11 +386,11 @@ class _Executor:
         if not do_convert:
             return phase, True
 
-        if auto_parallel_mode and "train" in phase:
+        if auto_parallel_mode:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
-        self._params_init_data(obj, params)
+        self._params_init_data(obj, params, auto_parallel_mode)
         if not enable_debug_runtime or enable_ge:
-            if auto_parallel_mode and "train" in phase:
+            if auto_parallel_mode:
                 obj.load_parameter_slice(params)
 
         # set parallel inputs in sink mode
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index dc23d4e7f1..e7f96601bb 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -98,6 +98,10 @@ class Parameter:
         """Get slice status of the parameter."""
         return self._sliced
 
+    @sliced.setter
+    def sliced(self, sliced_):
+        self._sliced = sliced_
+
     @property
     def is_init(self):
         """Get init status of the parameter."""
@@ -206,15 +210,18 @@ class Parameter:
 
         self.default_input = data
 
-    def init_data(self, layout=None):
+    def init_data(self, layout=None, set_sliced=False):
         """
         Init data of the parameter.
 
         Args:
-            layout (list[list[int]]): parameter slice layout [dev_mat, tensor_map, slice_shape].
-                dev_mat (list[int]): device matrix.
-                tensor_map (list[int]): tensor map.
-                slice_shape (list[int]): shape of slice.
+            layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape].
+
+                - dev_mat (list[int]): Device matrix.
+                - tensor_map (list[int]): Tensor map.
+                - slice_shape (list[int]): Shape of slice.
+            set_sliced (bool): True if the parameter should be marked as sliced after its
+                data is initialized. Default: False.
""" if not isinstance(self.default_input, MetaTensor): return @@ -230,7 +237,8 @@ class Parameter: self.default_input = self.init_mode.to_tensor() self.init_mode = None - self._sliced = True + if set_sliced: + self.sliced = True class ParameterTuple(tuple): diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index e0563a05fa..e6d2dc7383 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -264,11 +264,12 @@ class Cell: logger.info("layout dict does not contain the key %s", key) continue if self.parameters_dict()[key].sliced: - logger.info("Param %s is from initializer, already sliced.", key) + logger.info("Param %s is already sliced.", key) continue layout = self.parameter_layout_dict[key] new_tensor = _load_tensor_by_layout(tensor, layout) self.parameters_dict()[key].set_parameter_data(new_tensor) + self.parameters_dict()[key].sliced = True elif isinstance(params, OrderedDict): for key in params: tensor = params[key].data @@ -276,11 +277,12 @@ class Cell: logger.info("layout dict does not contain the key %s", key) continue if params[key].sliced: - logger.info("Param %s is from initializer, already sliced.", key) + logger.info("Param %s is already sliced.", key) continue layout = self.parameter_layout_dict[key] new_tensor = _load_tensor_by_layout(tensor, layout) params[key].set_parameter_data(new_tensor) + params[key].sliced = True else: raise TypeError('Parameters need OrderedDict type, but got {}'. format(type(params))) @@ -435,14 +437,17 @@ class Cell: """ raise NotImplementedError - def init_parameters_data(self, recurse=True): + def init_parameters_data(self, recurse=True, auto_parallel_mode=False): + """Init parameters' data.""" for param in self.get_parameters(expand=recurse): - if param.name not in self.parameter_layout_dict: - logger.info("Layout dict does not contain the key %s.", param.name) + if not auto_parallel_mode: param.init_data() + elif param.name not in self.parameter_layout_dict: + logger.info("Layout dict does not contain the key %s.", param.name) + param.init_data(set_sliced=True) else: layout = self.parameter_layout_dict[param.name] - param.init_data(layout) + param.init_data(layout, set_sliced=True) def parameters_dict(self, recurse=True): """