Merge pull request !1510 from yihuaijie/dev
@@ -327,16 +327,19 @@ class _Executor:
             raise TypeError('Parameters need OrderedDict type, but got {}'.
                             format(type(params)))

-    def _params_init_data(self, obj, params):
+    def _params_init_data(self, obj, params, auto_parallel_mode=False):
         """Init parameters' data."""
         if params is not None:
             for key, param in params.items():
-                if key not in obj.parameter_layout_dict:
-                    logger.info("Layout dict does not contain the key %s.", key)
+                if not auto_parallel_mode:
                     param.init_data()
+                elif key not in obj.parameter_layout_dict:
+                    logger.info("Layout dict does not contain the key %s.", key)
+                    param.init_data(set_sliced=True)
                 else:
                     layout = obj.parameter_layout_dict[key]
-                    param.init_data(layout)
-        obj.init_parameters_data()
+                    param.init_data(layout, set_sliced=True)
+        obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)

-    def compile(self, obj, *args, phase='predict', params=None, do_convert=True):
+    def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
         """
@@ -383,11 +386,11 @@ class _Executor:
         if not do_convert:
             return phase, True

-        if auto_parallel_mode and "train" in phase:
+        if auto_parallel_mode:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
-        self._params_init_data(obj, params)
+        self._params_init_data(obj, params, auto_parallel_mode)
         if not enable_debug_runtime or enable_ge:
-            if auto_parallel_mode and "train" in phase:
+            if auto_parallel_mode:
                 obj.load_parameter_slice(params)

         # set parallel inputs in sink mode
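Dropping the `"train" in phase` guard means layout retrieval and parameter slicing now also run for non-training graphs compiled under auto-parallel. A hedged usage sketch; net and inputs are placeholders, and _executor is assumed to be the module-level _Executor instance:

    # Before this patch, phase='eval' skipped both auto-parallel branches,
    # leaving parameters full-size; now any phase gets sliced parameters.
    phase, _ = _executor.compile(net, inputs, phase='eval',
                                 params=net.parameters_dict(),
                                 auto_parallel_mode=True)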
@@ -99,6 +99,10 @@ class Parameter:
         """Get slice status of the parameter."""
         return self._sliced

+    @sliced.setter
+    def sliced(self, sliced_):
+        self._sliced = sliced_
+
     @property
     def is_init(self):
         """Get init status of the parameter."""
@@ -211,15 +215,18 @@ class Parameter:
         self.default_input = data

-    def init_data(self, layout=None):
+    def init_data(self, layout=None, set_sliced=False):
         """
         Init data of the parameter.

         Args:
-            layout (list[list[int]]): parameter slice layout [dev_mat, tensor_map, slice_shape].
-                dev_mat (list[int]): device matrix.
-                tensor_map (list[int]): tensor map.
-                slice_shape (list[int]): shape of slice.
+            layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape].
+
+                - dev_mat (list[int]): Device matrix.
+                - tensor_map (list[int]): Tensor map.
+                - slice_shape (list[int]): Shape of slice.
+            set_sliced (bool): If True, mark the parameter as sliced after its data is initialized.
+                Default: False.
         """
         if not isinstance(self.default_input, MetaTensor):
             return
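For readers new to the layout triplet, a made-up example (the concrete numbers are illustrative assumptions, not taken from this patch):

    # An 8 x 8 parameter on 8 devices:
    dev_mat = [4, 2]        # logical device matrix, 4 * 2 = 8 devices
    tensor_map = [1, -1]    # dim 0 split over the size-4 axis; -1 = unsplit
    slice_shape = [2, 8]    # per-device slice: 8 / 4 = 2 rows, all 8 columns
    param.init_data(layout=[dev_mat, tensor_map, slice_shape], set_sliced=True)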
@@ -235,7 +242,8 @@ class Parameter:
         self.default_input = self.init_mode.to_tensor()
         self.init_mode = None
-        self._sliced = True
+        if set_sliced:
+            self.sliced = True


 class ParameterTuple(tuple):
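The guard changes init_data's side effect from unconditional to opt-in. A hedged before/after contrast on two freshly constructed initializer-backed parameters:

    p.init_data()                 # after this patch: p.sliced stays False
    q.init_data(set_sliced=True)  # q.sliced becomes True via the new setter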
@@ -264,11 +264,12 @@ class Cell:
                     logger.info("layout dict does not contain the key %s", key)
                     continue
                 if self.parameters_dict()[key].sliced:
-                    logger.info("Param %s is from initializer, already sliced.", key)
+                    logger.info("Param %s is already sliced.", key)
                     continue
                 layout = self.parameter_layout_dict[key]
                 new_tensor = _load_tensor_by_layout(tensor, layout)
                 self.parameters_dict()[key].set_parameter_data(new_tensor)
+                self.parameters_dict()[key].sliced = True
         elif isinstance(params, OrderedDict):
             for key in params:
                 tensor = params[key].data
@@ -276,11 +277,12 @@ class Cell:
                     logger.info("layout dict does not contain the key %s", key)
                     continue
                 if params[key].sliced:
-                    logger.info("Param %s is from initializer, already sliced.", key)
+                    logger.info("Param %s is already sliced.", key)
                     continue
                 layout = self.parameter_layout_dict[key]
                 new_tensor = _load_tensor_by_layout(tensor, layout)
                 params[key].set_parameter_data(new_tensor)
+                params[key].sliced = True
         else:
             raise TypeError('Parameters need OrderedDict type, but got {}'.
                             format(type(params)))
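Both branches now flag a parameter as sliced immediately after _load_tensor_by_layout, which makes load_parameter_slice safe to call twice. A hedged sketch (net stands in for any Cell):

    net.load_parameter_slice(params)  # slices every not-yet-sliced parameter
    net.load_parameter_slice(params)  # effectively a no-op: every key now
                                      # takes the "already sliced" continue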
@@ -435,14 +437,17 @@ class Cell:
         """
         raise NotImplementedError

-    def init_parameters_data(self, recurse=True):
+    def init_parameters_data(self, recurse=True, auto_parallel_mode=False):
         """Init parameters' data."""
         for param in self.get_parameters(expand=recurse):
-            if param.name not in self.parameter_layout_dict:
-                logger.info("Layout dict does not contain the key %s.", param.name)
+            if not auto_parallel_mode:
                 param.init_data()
+            elif param.name not in self.parameter_layout_dict:
+                logger.info("Layout dict does not contain the key %s.", param.name)
+                param.init_data(set_sliced=True)
             else:
                 layout = self.parameter_layout_dict[param.name]
-                param.init_data(layout)
+                param.init_data(layout, set_sliced=True)

     def parameters_dict(self, recurse=True):
         """