|
|
@@ -64,7 +64,7 @@ class Initializer: |
|
|
def dtype(self, dtype): |
|
|
def dtype(self, dtype): |
|
|
self._dtype = dtype |
|
|
self._dtype = dtype |
|
|
|
|
|
|
|
|
def to_tensor(self, slice_index=None): |
|
|
|
|
|
|
|
|
def to_tensor(self, slice_index=None, shape=None): |
|
|
""" |
|
|
""" |
|
|
Get the tensor format data of this Initializer. |
|
|
Get the tensor format data of this Initializer. |
|
|
|
|
|
|
|
|
@@ -72,12 +72,16 @@ class Initializer: |
|
|
slice_index (int): Slice index of a parameter's slices. |
|
|
slice_index (int): Slice index of a parameter's slices. |
|
|
Used when initialize a slice of a parameter, it guarantee that |
|
|
Used when initialize a slice of a parameter, it guarantee that |
|
|
devices use the same slice can generate the same tensor. |
|
|
devices use the same slice can generate the same tensor. |
|
|
|
|
|
shape (list[int]): Shape of the slice, used when initialize a slice of the parameter. |
|
|
""" |
|
|
""" |
|
|
arr = None |
|
|
arr = None |
|
|
|
|
|
if shape is None: |
|
|
|
|
|
shape = self.shape |
|
|
|
|
|
|
|
|
try: |
|
|
try: |
|
|
arr = np.ndarray(self.shape) |
|
|
|
|
|
|
|
|
arr = np.ndarray(shape) |
|
|
except ValueError: |
|
|
except ValueError: |
|
|
msg = "Error shape={}".format(self.shape) |
|
|
|
|
|
|
|
|
msg = "Error shape={}".format(shape) |
|
|
logger.error(msg) |
|
|
logger.error(msg) |
|
|
raise ValueError(msg) |
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
|
|