|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
"""Initializer for cell parameters.""" |
|
|
|
import numbers |
|
|
|
import math |
|
|
|
import copy |
|
|
|
|
|
|
|
from functools import reduce |
|
|
|
import numpy as np |
|
|
|
@@ -82,7 +83,7 @@ class Initializer: |
|
|
|
shape = self.shape |
|
|
|
|
|
|
|
try: |
|
|
|
arr = np.ndarray(shape) |
|
|
|
arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype)) |
|
|
|
except ValueError: |
|
|
|
msg = "Error shape={}".format(shape) |
|
|
|
logger.error(msg) |
|
|
|
@@ -478,9 +479,10 @@ def initializer(init, shape=None, dtype=mstype.float32): |
|
|
|
raise ValueError(f"shape is invalid, shape value must be positive integer, shape:{shape}") |
|
|
|
|
|
|
|
if isinstance(init, Initializer): |
|
|
|
init.shape = init.shape if init.shape is not None else shape |
|
|
|
init.dtype = init.dtype if init.dtype is not None else dtype |
|
|
|
return init |
|
|
|
init_copy = copy.deepcopy(init) |
|
|
|
init_copy.shape = shape if shape is not None else init.shape |
|
|
|
init_copy.dtype = init.dtype if init.dtype is not None else dtype |
|
|
|
return init_copy |
|
|
|
|
|
|
|
if isinstance(init, str): |
|
|
|
init_obj = _INITIALIZER_ALIAS[init.lower()]() |
|
|
|
|