| @@ -41,7 +41,6 @@ class Initializer: | |||||
| self._kwargs = kwargs | self._kwargs = kwargs | ||||
| self.shape = None | self.shape = None | ||||
| self.dtype = None | self.dtype = None | ||||
| self._seed = None | |||||
| def _initialize(self, *kwargs): | def _initialize(self, *kwargs): | ||||
| raise NotImplementedError('Must be overridden!') | raise NotImplementedError('Must be overridden!') | ||||
| @@ -49,15 +48,6 @@ class Initializer: | |||||
| def __call__(self, arr): | def __call__(self, arr): | ||||
| return self._initialize(arr) | return self._initialize(arr) | ||||
| @property | |||||
| def seed(self): | |||||
| return self._seed | |||||
| @seed.setter | |||||
| def seed(self, seed_): | |||||
| """set the random seed.""" | |||||
| self._seed = seed_ | |||||
| @property | @property | ||||
| def shape(self): | def shape(self): | ||||
| return self._shape | return self._shape | ||||
| @@ -74,8 +64,15 @@ class Initializer: | |||||
| def dtype(self, dtype): | def dtype(self, dtype): | ||||
| self._dtype = dtype | self._dtype = dtype | ||||
| def to_tensor(self): | |||||
| """Get the tensor format data of this Initializer.""" | |||||
| def to_tensor(self, slice_index=None): | |||||
| """ | |||||
| Get the tensor format data of this Initializer. | |||||
| Args: | |||||
| slice_index (int): Slice index of a parameter's slices. | |||||
| Used when initializing a slice of a parameter; it guarantees that | |||||
| devices using the same slice generate the same tensor. | |||||
| """ | |||||
| arr = None | arr = None | ||||
| try: | try: | ||||
| arr = np.ndarray(self.shape) | arr = np.ndarray(self.shape) | ||||
| @@ -83,10 +80,10 @@ class Initializer: | |||||
| msg = "Error shape={}".format(self.shape) | msg = "Error shape={}".format(self.shape) | ||||
| logger.error(msg) | logger.error(msg) | ||||
| raise ValueError(msg) | raise ValueError(msg) | ||||
| if self._seed is not None: | |||||
| np.random.seed(self.seed) | |||||
| if slice_index is not None: | |||||
| np.random.seed(slice_index) | |||||
| self.__call__(arr) | self.__call__(arr) | ||||
| self._seed = None | |||||
| return Tensor(arr, dtype=self.dtype) | return Tensor(arr, dtype=self.dtype) | ||||
| def _register(*aliases): | def _register(*aliases): | ||||
| @@ -22,7 +22,7 @@ from .initializer import initializer, Initializer | |||||
| from .tensor import Tensor, MetaTensor | from .tensor import Tensor, MetaTensor | ||||
| from .._checkparam import _check_str_by_regular | from .._checkparam import _check_str_by_regular | ||||
| from ..parallel._utils import _set_clone_info, _CloneInfo | from ..parallel._utils import _set_clone_info, _CloneInfo | ||||
| from ..parallel._tensor import _get_seed | |||||
| from ..parallel._tensor import _get_slice_index | |||||
| __all__ = ['Parameter', 'ParameterTuple'] | __all__ = ['Parameter', 'ParameterTuple'] | ||||
| @@ -250,9 +250,11 @@ class Parameter: | |||||
| raise ValueError("The length of layout must be 3! layout is {}." | raise ValueError("The length of layout must be 3! layout is {}." | ||||
| .format(layout)) | .format(layout)) | ||||
| self.init_mode.shape = layout[2] | self.init_mode.shape = layout[2] | ||||
| self.init_mode.seed = int(_get_seed(layout[0], layout[1])) | |||||
| slice_index = int(_get_slice_index(layout[0], layout[1])) | |||||
| self.default_input = self.init_mode.to_tensor(slice_index) | |||||
| else: | |||||
| self.default_input = self.init_mode.to_tensor() | |||||
| self.default_input = self.init_mode.to_tensor() | |||||
| self.init_mode = None | self.init_mode = None | ||||
| if set_sliced: | if set_sliced: | ||||
| self.sliced = True | self.sliced = True | ||||
| @@ -168,21 +168,21 @@ def _chunk_tensor_by_strategy(np_tensor, strategy): | |||||
| raise ValueError("The length of np_tensor does not match the length of strategy!") | raise ValueError("The length of np_tensor does not match the length of strategy!") | ||||
| return _chunk_tensor(np_tensor, strategy, len(strategy)) | return _chunk_tensor(np_tensor, strategy, len(strategy)) | ||||
| def _get_seed(dev_mat, tensor_map): | |||||
| def _get_slice_index(dev_mat, tensor_map): | |||||
| """ | """ | ||||
| Get the random seed for current slice. | |||||
| Get the slice index for the current slice. | |||||
| Args: | Args: | ||||
| dev_mat (list): The device matrix of devices. | dev_mat (list): The device matrix of devices. | ||||
| tensor_map (list): The split strategy of tensor. | tensor_map (list): The split strategy of tensor. | ||||
| Returns: | Returns: | ||||
| Integer, the local random seed for this device. | |||||
| Integer, the slice index for the slice on this device. | |||||
| """ | """ | ||||
| rank = get_rank() | rank = get_rank() | ||||
| tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) | tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) | ||||
| tensor_slice_seed = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) | |||||
| return tensor_slice_seed | |||||
| tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) | |||||
| return tensor_slice_index | |||||
| def _load_tensor(tensor, dev_mat, tensor_map): | def _load_tensor(tensor, dev_mat, tensor_map): | ||||
| """ | """ | ||||