| @@ -41,7 +41,6 @@ class Initializer: | |||
| self._kwargs = kwargs | |||
| self.shape = None | |||
| self.dtype = None | |||
| self._seed = None | |||
| def _initialize(self, *kwargs): | |||
| raise NotImplementedError('Must be overridden!') | |||
| @@ -49,15 +48,6 @@ class Initializer: | |||
| def __call__(self, arr): | |||
| return self._initialize(arr) | |||
| @property | |||
| def seed(self): | |||
| return self._seed | |||
| @seed.setter | |||
| def seed(self, seed_): | |||
| """set the random seed.""" | |||
| self._seed = seed_ | |||
| @property | |||
| def shape(self): | |||
| return self._shape | |||
| @@ -74,8 +64,15 @@ class Initializer: | |||
| def dtype(self, dtype): | |||
| self._dtype = dtype | |||
| def to_tensor(self): | |||
| """Get the tensor format data of this Initializer.""" | |||
| def to_tensor(self, slice_index=None): | |||
| """ | |||
| Get the tensor format data of this Initializer. | |||
| Args: | |||
| slice_index (int): Slice index of a parameter's slices. | |||
| Used when initializing a slice of a parameter; it guarantees that | |||
| devices using the same slice generate the same tensor. | |||
| """ | |||
| arr = None | |||
| try: | |||
| arr = np.ndarray(self.shape) | |||
| @@ -83,10 +80,10 @@ class Initializer: | |||
| msg = "Error shape={}".format(self.shape) | |||
| logger.error(msg) | |||
| raise ValueError(msg) | |||
| if self._seed is not None: | |||
| np.random.seed(self.seed) | |||
| if slice_index is not None: | |||
| np.random.seed(slice_index) | |||
| self.__call__(arr) | |||
| self._seed = None | |||
| return Tensor(arr, dtype=self.dtype) | |||
| def _register(*aliases): | |||
| @@ -22,7 +22,7 @@ from .initializer import initializer, Initializer | |||
| from .tensor import Tensor, MetaTensor | |||
| from .._checkparam import _check_str_by_regular | |||
| from ..parallel._utils import _set_clone_info, _CloneInfo | |||
| from ..parallel._tensor import _get_seed | |||
| from ..parallel._tensor import _get_slice_index | |||
| __all__ = ['Parameter', 'ParameterTuple'] | |||
| @@ -250,9 +250,11 @@ class Parameter: | |||
| raise ValueError("The length of layout must be 3! layout is {}." | |||
| .format(layout)) | |||
| self.init_mode.shape = layout[2] | |||
| self.init_mode.seed = int(_get_seed(layout[0], layout[1])) | |||
| slice_index = int(_get_slice_index(layout[0], layout[1])) | |||
| self.default_input = self.init_mode.to_tensor(slice_index) | |||
| else: | |||
| self.default_input = self.init_mode.to_tensor() | |||
| self.default_input = self.init_mode.to_tensor() | |||
| self.init_mode = None | |||
| if set_sliced: | |||
| self.sliced = True | |||
| @@ -168,21 +168,21 @@ def _chunk_tensor_by_strategy(np_tensor, strategy): | |||
| raise ValueError("The length of np_tensor does not match the length of strategy!") | |||
| return _chunk_tensor(np_tensor, strategy, len(strategy)) | |||
| def _get_seed(dev_mat, tensor_map): | |||
| def _get_slice_index(dev_mat, tensor_map): | |||
| """ | |||
| Get the random seed for current slice. | |||
| Get the slice index for current slice. | |||
| Args: | |||
| dev_mat (list): The device matrix of devices. | |||
| tensor_map (list): The split strategy of tensor. | |||
| Returns: | |||
| Integer, the local random seed for this device. | |||
| Integer, the index of the slice held by this device. | |||
| """ | |||
| rank = get_rank() | |||
| tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) | |||
| tensor_slice_seed = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) | |||
| return tensor_slice_seed | |||
| tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) | |||
| return tensor_slice_index | |||
| def _load_tensor(tensor, dev_mat, tensor_map): | |||
| """ | |||