|
|
|
@@ -41,7 +41,6 @@ class Initializer: |
|
|
|
self._kwargs = kwargs |
|
|
|
self.shape = None |
|
|
|
self.dtype = None |
|
|
|
self._seed = None |
|
|
|
|
|
|
|
def _initialize(self, *kwargs): |
|
|
|
raise NotImplementedError('Must be overridden!') |
|
|
|
@@ -49,15 +48,6 @@ class Initializer: |
|
|
|
def __call__(self, arr): |
|
|
|
return self._initialize(arr) |
|
|
|
|
|
|
|
@property |
|
|
|
def seed(self): |
|
|
|
return self._seed |
|
|
|
|
|
|
|
@seed.setter |
|
|
|
def seed(self, seed_): |
|
|
|
"""set the random seed.""" |
|
|
|
self._seed = seed_ |
|
|
|
|
|
|
|
@property |
|
|
|
def shape(self): |
|
|
|
return self._shape |
|
|
|
@@ -74,8 +64,15 @@ class Initializer: |
|
|
|
def dtype(self, dtype): |
|
|
|
self._dtype = dtype |
|
|
|
|
|
|
|
def to_tensor(self): |
|
|
|
"""Get the tensor format data of this Initializer.""" |
|
|
|
def to_tensor(self, slice_index=None): |
|
|
|
""" |
|
|
|
Get the tensor format data of this Initializer. |
|
|
|
|
|
|
|
Args: |
|
|
|
slice_index (int): Slice index of a parameter's slices. |
|
|
|
Used when initialize a slice of a parameter, it guarantee that |
|
|
|
devices use the same slice can generate the same tensor. |
|
|
|
""" |
|
|
|
arr = None |
|
|
|
try: |
|
|
|
arr = np.ndarray(self.shape) |
|
|
|
@@ -83,10 +80,10 @@ class Initializer: |
|
|
|
msg = "Error shape={}".format(self.shape) |
|
|
|
logger.error(msg) |
|
|
|
raise ValueError(msg) |
|
|
|
if self._seed is not None: |
|
|
|
np.random.seed(self.seed) |
|
|
|
|
|
|
|
if slice_index is not None: |
|
|
|
np.random.seed(slice_index) |
|
|
|
self.__call__(arr) |
|
|
|
self._seed = None |
|
|
|
return Tensor(arr, dtype=self.dtype) |
|
|
|
|
|
|
|
def _register(*aliases): |
|
|
|
|