| @@ -327,16 +327,19 @@ class _Executor: | |||||
| raise TypeError('Parameters need OrderedDict type, but got {}'. | raise TypeError('Parameters need OrderedDict type, but got {}'. | ||||
| format(type(params))) | format(type(params))) | ||||
| def _params_init_data(self, obj, params): | |||||
| def _params_init_data(self, obj, params, auto_parallel_mode=False): | |||||
| """Init parameters' data.""" | |||||
| if params is not None: | if params is not None: | ||||
| for key, param in params.items(): | for key, param in params.items(): | ||||
| if key not in obj.parameter_layout_dict: | |||||
| logger.info("Layout dict does not contain the key %s.", key) | |||||
| if not auto_parallel_mode: | |||||
| param.init_data() | param.init_data() | ||||
| elif key not in obj.parameter_layout_dict: | |||||
| logger.info("Layout dict does not contain the key %s.", key) | |||||
| param.init_data(set_sliced=True) | |||||
| else: | else: | ||||
| layout = obj.parameter_layout_dict[key] | layout = obj.parameter_layout_dict[key] | ||||
| param.init_data(layout) | |||||
| obj.init_parameters_data() | |||||
| param.init_data(layout, set_sliced=True) | |||||
| obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) | |||||
| def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False): | def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False): | ||||
| """ | """ | ||||
| @@ -383,11 +386,11 @@ class _Executor: | |||||
| if not do_convert: | if not do_convert: | ||||
| return phase, True | return phase, True | ||||
| if auto_parallel_mode and "train" in phase: | |||||
| if auto_parallel_mode: | |||||
| obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) | obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) | ||||
| self._params_init_data(obj, params) | |||||
| self._params_init_data(obj, params, auto_parallel_mode) | |||||
| if not enable_debug_runtime or enable_ge: | if not enable_debug_runtime or enable_ge: | ||||
| if auto_parallel_mode and "train" in phase: | |||||
| if auto_parallel_mode: | |||||
| obj.load_parameter_slice(params) | obj.load_parameter_slice(params) | ||||
| # set parallel inputs in sink mode | # set parallel inputs in sink mode | ||||
| @@ -98,6 +98,10 @@ class Parameter: | |||||
| """Get slice status of the parameter.""" | """Get slice status of the parameter.""" | ||||
| return self._sliced | return self._sliced | ||||
@sliced.setter
def sliced(self, sliced_):
    """Set the slice status of the parameter."""
    self._sliced = sliced_
| @property | @property | ||||
| def is_init(self): | def is_init(self): | ||||
| """Get init status of the parameter.""" | """Get init status of the parameter.""" | ||||
| @@ -206,15 +210,18 @@ class Parameter: | |||||
| self.default_input = data | self.default_input = data | ||||
| def init_data(self, layout=None): | |||||
| def init_data(self, layout=None, set_sliced=False): | |||||
| """ | """ | ||||
| Init data of the parameter. | Init data of the parameter. | ||||
| Args: | Args: | ||||
| layout (list[list[int]]): parameter slice layout [dev_mat, tensor_map, slice_shape]. | |||||
| dev_mat (list[int]): device matrix. | |||||
| tensor_map (list[int]): tensor map. | |||||
| slice_shape (list[int]): shape of slice. | |||||
| layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape]. | |||||
| - dev_mat (list[int]): Device matrix. | |||||
| - tensor_map (list[int]): Tensor map. | |||||
| - slice_shape (list[int]): Shape of slice. | |||||
| set_sliced (bool): True if should set parameter sliced after init the data of initializer. | |||||
| Default: False. | |||||
| """ | """ | ||||
| if not isinstance(self.default_input, MetaTensor): | if not isinstance(self.default_input, MetaTensor): | ||||
| return | return | ||||
| @@ -230,7 +237,8 @@ class Parameter: | |||||
| self.default_input = self.init_mode.to_tensor() | self.default_input = self.init_mode.to_tensor() | ||||
| self.init_mode = None | self.init_mode = None | ||||
| self._sliced = True | |||||
| if set_sliced: | |||||
| self.sliced = True | |||||
| class ParameterTuple(tuple): | class ParameterTuple(tuple): | ||||
| @@ -264,11 +264,12 @@ class Cell: | |||||
| logger.info("layout dict does not contain the key %s", key) | logger.info("layout dict does not contain the key %s", key) | ||||
| continue | continue | ||||
| if self.parameters_dict()[key].sliced: | if self.parameters_dict()[key].sliced: | ||||
| logger.info("Param %s is from initializer, already sliced.", key) | |||||
| logger.info("Param %s is already sliced.", key) | |||||
| continue | continue | ||||
| layout = self.parameter_layout_dict[key] | layout = self.parameter_layout_dict[key] | ||||
| new_tensor = _load_tensor_by_layout(tensor, layout) | new_tensor = _load_tensor_by_layout(tensor, layout) | ||||
| self.parameters_dict()[key].set_parameter_data(new_tensor) | self.parameters_dict()[key].set_parameter_data(new_tensor) | ||||
| self.parameters_dict()[key].sliced = True | |||||
| elif isinstance(params, OrderedDict): | elif isinstance(params, OrderedDict): | ||||
| for key in params: | for key in params: | ||||
| tensor = params[key].data | tensor = params[key].data | ||||
| @@ -276,11 +277,12 @@ class Cell: | |||||
| logger.info("layout dict does not contain the key %s", key) | logger.info("layout dict does not contain the key %s", key) | ||||
| continue | continue | ||||
| if params[key].sliced: | if params[key].sliced: | ||||
| logger.info("Param %s is from initializer, already sliced.", key) | |||||
| logger.info("Param %s is already sliced.", key) | |||||
| continue | continue | ||||
| layout = self.parameter_layout_dict[key] | layout = self.parameter_layout_dict[key] | ||||
| new_tensor = _load_tensor_by_layout(tensor, layout) | new_tensor = _load_tensor_by_layout(tensor, layout) | ||||
| params[key].set_parameter_data(new_tensor) | params[key].set_parameter_data(new_tensor) | ||||
| params[key].sliced = True | |||||
| else: | else: | ||||
| raise TypeError('Parameters need OrderedDict type, but got {}'. | raise TypeError('Parameters need OrderedDict type, but got {}'. | ||||
| format(type(params))) | format(type(params))) | ||||
| @@ -435,14 +437,17 @@ class Cell: | |||||
| """ | """ | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
def init_parameters_data(self, recurse=True, auto_parallel_mode=False):
    """Init parameters' data.

    Args:
        recurse (bool): If True, also initialize parameters of sub-cells.
            Default: True.
        auto_parallel_mode (bool): If True, initialize each parameter from its
            slice layout (when one is recorded in ``parameter_layout_dict``)
            and mark it as sliced. Default: False.
    """
    for param in self.get_parameters(expand=recurse):
        if not auto_parallel_mode:
            # Stand-alone mode: plain initialization, no slicing.
            param.init_data()
        elif param.name not in self.parameter_layout_dict:
            # No layout for this parameter: init the full tensor, but still
            # flag it sliced so it is not re-sliced later.
            logger.info("Layout dict does not contain the key %s.", param.name)
            param.init_data(set_sliced=True)
        else:
            layout = self.parameter_layout_dict[param.name]
            param.init_data(layout, set_sliced=True)
| def parameters_dict(self, recurse=True): | def parameters_dict(self, recurse=True): | ||||
| """ | """ | ||||