From: @Somnus2020
Reviewed-by: @kingxian
Signed-off-by: @kingxian
tags/v1.2.0-rc1
| @@ -21,7 +21,7 @@ import numpy as np | |||
| from scipy.stats import truncnorm | |||
| from .seed import get_seed, _get_graph_seed | |||
| from . import dtype as mstype | |||
| from .tensor import Tensor, MetaTensor | |||
| from .tensor import Tensor | |||
| from .._c_expression import random_normal | |||
| _INITIALIZER_ALIAS = dict() | |||
| @@ -416,8 +416,8 @@ def initializer(init, shape=None, dtype=mstype.float32): | |||
| dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32. | |||
| Returns: | |||
| Union[Tensor, MetaTensor], When `init` is Tensor, the return is Tensor object, | |||
| otherwise the return is Initialize object. | |||
| Union[Tensor], return is Tensor object. | |||
| Examples: | |||
| >>> tensor = initializer('ones', [1, 2, 3], mindspore.float32) | |||
| @@ -451,7 +451,7 @@ def initializer(init, shape=None, dtype=mstype.float32): | |||
| elif isinstance(init, numbers.Number): | |||
| init = Constant(init) | |||
| shape = shape if shape is not None else init.shape | |||
| init_obj = MetaTensor(dtype, shape, init) | |||
| init_obj = Tensor(dtype=dtype, shape=shape, init=init) | |||
| return init_obj | |||
| __all__ = [ | |||
| @@ -17,11 +17,11 @@ | |||
| from copy import copy | |||
| import numbers | |||
| from .._c_expression import ParamInfo | |||
| from .._c_expression import MetaTensor as MetaTensor_ | |||
| from . import dtype as mstype | |||
| from .initializer import initializer | |||
| from .tensor import Tensor, MetaTensor | |||
| from .tensor import Tensor | |||
| from .._checkparam import Validator | |||
| from .._c_expression import Tensor as Tensor_ | |||
| from ..parallel._tensor import _get_slice_index | |||
| from ..parallel._auto_parallel_context import auto_parallel_context | |||
| from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table | |||
| @@ -52,15 +52,15 @@ def init_to_value(init): | |||
| raise ValueError("init should be number or string") | |||
| class Parameter(MetaTensor_): | |||
| class Parameter(Tensor_): | |||
| """ | |||
| Parameter types of cell models. | |||
| After initialized `Parameter` is a subtype of `Tensor`. | |||
| In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by | |||
| an `MetaTensor`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor_` | |||
| only saves the shape and type info of a tensor with no memory usage. The shape can be changed while | |||
| an `Tensor`, the type of Parameter will be `Tensor`. `Tensor` | |||
| will save the shape and type info of a tensor with no memory usage. The shape can be changed while | |||
| compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. | |||
| Note: | |||
| @@ -72,7 +72,7 @@ class Parameter(MetaTensor_): | |||
| otherwise, the parameter name may be different than expected. | |||
| Args: | |||
| default_input (Union[Tensor, MetaTensor, Number]): Parameter data, to be set initialized. | |||
| default_input (Union[Tensor, Number]): Parameter data, to be set initialized. | |||
| name (str): Name of the child parameter. Default: None. | |||
| requires_grad (bool): True if the parameter requires gradient. Default: True. | |||
| layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode, | |||
| @@ -116,12 +116,12 @@ class Parameter(MetaTensor_): | |||
| new_type = Parameter._get_base_class(input_class) | |||
| obj = input_class.__new__(new_type) | |||
| input_class.__init__(obj, *class_init_args) | |||
| # it's better to make the Initializer a kind of metatensor. | |||
| # it's better to make the Initializer a kind of tensor. | |||
| obj.init_mode = None | |||
| obj.is_default_input_meta = False | |||
| if isinstance(default_input, MetaTensor): | |||
| obj.is_default_input_meta = True | |||
| if not isinstance(obj, Tensor): | |||
| obj.is_default_input_init = False | |||
| if isinstance(default_input, Tensor) and default_input.has_init: | |||
| obj.is_default_input_init = True | |||
| if obj.has_init: | |||
| obj.init_mode = default_input | |||
| return obj | |||
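For `Parameter`, the old `isinstance(default_input, MetaTensor)` probe becomes `default_input.has_init`. Per `_get_parameter_new_args` below, a lazy input is materialized right away in standalone mode, while under parallel / parameter-server roles initialization stays deferred until `init_data()`. A hedged usage sketch:

```python
import mindspore as ms
from mindspore import Parameter
from mindspore.common.initializer import initializer

w = Parameter(initializer('normal', [3, 4], ms.float32), name='w')
w = w.init_data()            # harmless if the data was already materialized at construction
assert w.data.asnumpy().shape == (3, 4)
```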
| @@ -154,12 +154,18 @@ class Parameter(MetaTensor_): | |||
| self._cast_type = None | |||
| self._unique = False | |||
| self.is_in_parallel = _is_in_parallel_mode() | |||
| if isinstance(default_input, (MetaTensor, Tensor)): | |||
| MetaTensor_.__init__(self, default_input.dtype, default_input.shape) | |||
| if isinstance(default_input, Tensor): | |||
| Tensor_.__init__(self, default_input.dtype, default_input.shape) | |||
| elif isinstance(default_input, int): | |||
| MetaTensor_.__init__(self, mstype.int64, ()) | |||
| Tensor_.__init__(self, mstype.int64, ()) | |||
| elif isinstance(default_input, float): | |||
| MetaTensor_.__init__(self, mstype.float32, ()) | |||
| Tensor_.__init__(self, mstype.float32, ()) | |||
| def __deepcopy__(self, memodict): | |||
| new_obj = Parameter(self) | |||
| new_obj.name = self.name | |||
| new_obj._inited_param = self._inited_param # pylint: disable=W0212 | |||
| return new_obj | |||
| @staticmethod | |||
| def _get_base_class(input_class): | |||
| @@ -176,12 +182,12 @@ class Parameter(MetaTensor_): | |||
| """Set `set_data` of current `Parameter`.""" | |||
| if isinstance(data, bool): | |||
| raise ValueError('Parameter data can not be `bool`') | |||
| if isinstance(data, MetaTensor): | |||
| if isinstance(data, Tensor) and data.has_init: | |||
| if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched(): | |||
| # do not init data while in auto parallel. | |||
| return (MetaTensor_, data.dtype, data.shape) | |||
| data = data.to_tensor() | |||
| if isinstance(data, Tensor): | |||
| return (Tensor, None, data.dtype, data.shape, data.init) | |||
| data = data.init_data().asnumpy() | |||
| elif isinstance(data, Tensor): | |||
| # make a copy of Tensor to init the parameter | |||
| return (Tensor, data.asnumpy(),) | |||
| if isinstance(data, int): | |||
| @@ -322,7 +328,7 @@ class Parameter(MetaTensor_): | |||
| Clone the parameter. | |||
| Args: | |||
| init (Union[Tensor, str, MetaTensor, numbers.Number]): Initialize the shape of the parameter. | |||
| init (Union[Tensor, str, numbers.Number]): Initialize the shape of the parameter. | |||
| Default: 'same'. | |||
| Returns: | |||
| @@ -332,6 +338,7 @@ class Parameter(MetaTensor_): | |||
| # pylint: disable=protected-access | |||
| x._param_info = self._param_info.clone() | |||
| x.is_init = False | |||
| x.init = self.init | |||
| x.is_param_ps = self.is_param_ps | |||
| x.init_in_server = self.init_in_server | |||
| x.cache_enable = self.cache_enable | |||
| @@ -382,6 +389,7 @@ class Parameter(MetaTensor_): | |||
| if isinstance(self, Tensor): | |||
| # for Tensor same shape: | |||
| self.init_flag = False | |||
| self.init = None | |||
| return self.assign_value(data) | |||
| # create a new tensor | |||
| return Parameter(data, self.name, self.requires_grad) | |||
| @@ -391,7 +399,7 @@ class Parameter(MetaTensor_): | |||
| Set `set_data` of current `Parameter`. | |||
| Args: | |||
| data (Union[Tensor, MetaTensor, int, float]): new data. | |||
| data (Union[Tensor, int, float]): new data. | |||
| slice_shape (bool): If slice the parameter is set to true, the shape is not checked for consistency. | |||
| Default: False. | |||
| @@ -403,20 +411,20 @@ class Parameter(MetaTensor_): | |||
| f"Current dtype is {self.dtype}, and incoming is {incoming}. " | |||
| f"Use .set_dtype(xxx) to change the dtype.") | |||
| if not isinstance(data, (MetaTensor_, int, float)): | |||
| raise TypeError(f"Parameter data must be [`MetaTensor`, `int`, `float`] or a kind of `MetaTensor_` " | |||
| f"(like `Tensor` or `MetaTensor_`). But with type {type(data)}.") | |||
| if not isinstance(data, (Tensor, int, float)): | |||
| raise TypeError(f"Parameter data must be [`Tensor`, `int`, `float`] or a kind of `Tensor` " | |||
| f"(like `Tensor`). But with type {type(data)}.") | |||
| if isinstance(data, (int, float)): | |||
| if self.dtype in mstype.int_type and isinstance(data, float): | |||
| raise_type_error(mstype.float_) | |||
| data = Tensor(data, self.dtype) | |||
| # both not init. | |||
| is_incoming_tensor = isinstance(data, Tensor) | |||
| is_current_tensor = isinstance(self, Tensor) | |||
| incoming_tensor_is_init = isinstance(data, Tensor) and not data.has_init | |||
| current_tensor_is_init = isinstance(self, Tensor) and not self.has_init | |||
| if is_incoming_tensor and not is_current_tensor: | |||
| raise TypeError("Parameter is a `MetaTensor_` and not initializered, `data` for `set_data`" | |||
| "should be a MetaTensor. If you want to update it by Tensor, call method" | |||
| if incoming_tensor_is_init and not current_tensor_is_init: | |||
| raise TypeError("Parameter is a `Tensor` and not initializered, `data` for `set_data`" | |||
| "should be a Tensor. If you want to update it by Tensor, call method" | |||
| "`init_parameters_data` of `Cell` to init and replace all the Parameter of" | |||
| "network, then call this method.") | |||
| if tuple(self.shape) != tuple(data.shape): | |||
| @@ -429,16 +437,16 @@ class Parameter(MetaTensor_): | |||
| raise_type_error(data.dtype) | |||
| else: | |||
| data = Tensor(data, self.dtype) | |||
| if isinstance(data, MetaTensor): | |||
| if isinstance(data, Tensor) and data.has_init: | |||
| # The parameter has been initializered, directly update by the data | |||
| if is_current_tensor: | |||
| self._update_tensor_data(data.to_tensor()) | |||
| if current_tensor_is_init: | |||
| self._update_tensor_data(data.init_data()) | |||
| else: | |||
| # also update the related inited parameter data | |||
| if self.inited_param is not None: | |||
| self.inited_param.set_data(data) | |||
| self.init_mode = data | |||
| elif is_incoming_tensor or is_current_tensor: | |||
| elif incoming_tensor_is_init or current_tensor_is_init: | |||
| self._update_tensor_data(data) | |||
| else: | |||
| raise ValueError(f"Not support to update the Parameter by {data}") | |||
| @@ -465,10 +473,10 @@ class Parameter(MetaTensor_): | |||
| Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before, | |||
| returns the same initialized `Parameter`. | |||
| """ | |||
| if self.is_default_input_meta: | |||
| if self.is_default_input_init: | |||
| is_current_in_parallel = _is_in_parallel_mode() | |||
| if self.is_in_parallel != is_current_in_parallel: | |||
| raise RuntimeError("Must set or change parallel mode before any MetaTensor created.") | |||
| raise RuntimeError("Must set or change parallel mode before any Tensor created.") | |||
| if self.init_mode is None: | |||
| return self | |||
| if self.inited_param is not None: | |||
| @@ -482,21 +490,23 @@ class Parameter(MetaTensor_): | |||
| if len(layout) < 3: | |||
| raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) | |||
| slice_index = int(_get_slice_index(layout[0], layout[1])) | |||
| if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): | |||
| if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) | |||
| and self.init_mode.init is not None): | |||
| if _is_role_worker() or _is_role_sched(): | |||
| data = self.init_mode.to_tensor(0, [1]) | |||
| data = self.init_mode.init_data(0, [1]) | |||
| else: | |||
| data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) | |||
| data = self.init_mode.init_data(slice_index, layout[2], layout[5]) | |||
| else: | |||
| data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) | |||
| data = self.init_mode.init_data(slice_index, layout[2], layout[5]) | |||
| else: | |||
| if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): | |||
| if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) | |||
| and self.init_mode.init is not None): | |||
| if _is_role_worker() or _is_role_sched(): | |||
| data = self.init_mode.to_tensor(0, [1]) | |||
| data = self.init_mode.init_data(0, [1]) | |||
| else: | |||
| data = self.init_mode.to_tensor() | |||
| data = self.init_mode.init_data() | |||
| else: | |||
| data = self.init_mode.to_tensor() | |||
| data = self.init_mode.init_data() | |||
| obj = self._update_tensor_data(data) | |||
| if id(obj) != id(self): | |||
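`Parameter.init_data` now delegates to `Tensor.init_data` (formerly `MetaTensor.to_tensor`), still threading the slice index and layout through for auto-parallel and parameter-server initialization. A hedged sketch of slicing a lazy tensor directly:

```python
import mindspore as ms
from mindspore.common.initializer import initializer

full = initializer('normal', [4, 4], ms.float32)     # lazy tensor for the full parameter
part = full.init_data(slice_index=0, shape=[2, 4])   # materialize just one slice of it
assert part.shape == (2, 4)
```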
| @@ -19,7 +19,7 @@ from mindspore import log as logger | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| from . import dtype as mstype | |||
| from ._register_for_tensor import tensor_operator_registry | |||
| from .._c_expression import MetaTensor as MetaTensor_ | |||
| from .._c_expression import MetaTensor | |||
| from .._c_expression import Tensor as Tensor_ | |||
| from .._checkparam import Validator as validator | |||
| @@ -60,32 +60,46 @@ class Tensor(Tensor_): | |||
| >>> assert t2.dtype == mindspore.float64 | |||
| """ | |||
| def __init__(self, input_data, dtype=None): | |||
| def __init__(self, input_data=None, dtype=None, shape=None, init=None): | |||
| # If input data is numpy number, convert it to np array | |||
| if isinstance(input_data, np_types): | |||
| input_data = np.array(input_data) | |||
| if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False: | |||
| raise TypeError("input_data and init can not be None at the same time.") | |||
| # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. | |||
| validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), | |||
| 'Tensor') | |||
| valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, | |||
| np.float16, np.float32, np.float64, np.bool_) | |||
| if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes: | |||
| raise TypeError(f"For Tensor, the input_data is a numpy array, " | |||
| f"but it's data type is not in supported list: {list(i.__name__ for i in valid_dtypes)}.") | |||
| if isinstance(input_data, (tuple, list)): | |||
| if np.array(input_data).dtype not in valid_dtypes: | |||
| raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.") | |||
| if dtype is not None: | |||
| validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor") | |||
| if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']): | |||
| input_data = np.ascontiguousarray(input_data) | |||
| if dtype is None: | |||
| Tensor_.__init__(self, input_data) | |||
| if init is None: | |||
| validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), | |||
| 'Tensor') | |||
| valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, | |||
| np.float16, np.float32, np.float64, np.bool_) | |||
| if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes: | |||
| raise TypeError(f"For Tensor, the input_data is a numpy array, " | |||
| f"but it's data type is not in supported list:\ | |||
| {list(i.__name__ for i in valid_dtypes)}.") | |||
| if isinstance(input_data, (tuple, list)): | |||
| if np.array(input_data).dtype not in valid_dtypes: | |||
| raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.") | |||
| if dtype is not None: | |||
| validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor") | |||
| if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']): | |||
| input_data = np.ascontiguousarray(input_data) | |||
| if dtype is None: | |||
| Tensor_.__init__(self, input_data) | |||
| else: | |||
| Tensor_.__init__(self, input_data, dtype) | |||
| else: | |||
| Tensor_.__init__(self, input_data, dtype) | |||
| Tensor_.__init__(self, dtype, shape) | |||
| self._virtual_flag = False | |||
| self.init = init | |||
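`Tensor.__init__` gains `shape` and `init` keyword arguments, and exactly one of `input_data` and `init` must be given. With `init`, only metadata is stored (the old MetaTensor role); with `input_data`, the usual eager path runs. Hedged sketch:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.common.initializer import One

eager = Tensor(np.zeros((2, 2), np.float32))                # input_data path: has_init is False
lazy = Tensor(dtype=ms.float32, shape=[2, 2], init=One())   # metadata-only path: has_init is True
assert not eager.has_init and lazy.has_init
assert float(lazy.init_data().asnumpy().sum()) == 4.0
```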
| def __deepcopy__(self, memodict): | |||
| new_obj = Tensor(self) | |||
| new_obj.init = self.init | |||
| new_obj._virtual_flag = self._virtual_flag # pylint:disable=w0212 | |||
| return new_obj | |||
| def __repr__(self): | |||
| Tensor_.data_sync(self, False) | |||
| @@ -248,6 +262,11 @@ class Tensor(Tensor_): | |||
| """The ndim of tensor is an integer.""" | |||
| return len(self._shape) | |||
| @property | |||
| def has_init(self): | |||
| """tensor is inited.""" | |||
| return self.init is not None | |||
| @property | |||
| def virtual_flag(self): | |||
| """Mark tensor is virtual.""" | |||
| @@ -366,6 +385,69 @@ class Tensor(Tensor_): | |||
| return tensor_operator_registry.get('mean')(keep_dims)(self, axis) | |||
| def init_data(self, slice_index=None, shape=None, opt_shard_group=None): | |||
| """ | |||
| Get the tensor format data of this Tensor. | |||
| Args: | |||
| slice_index (int): Slice index of a parameter's slices. | |||
| It is used when initialize a slice of a parameter, it guarantees that devices | |||
| using the same slice can generate the same tensor. | |||
| shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter. | |||
| opt_shard_group(str): Optimizer shard group which is used in auto or semi auto parallel mode | |||
| to get one shard of a parameter's slice. | |||
| """ | |||
| if self.init is None: | |||
| raise TypeError("init_data must be set Tensor.init, init can't be None") | |||
| if shape is None: | |||
| shape = self.shape | |||
| try: | |||
| arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype)) | |||
| except ValueError: | |||
| msg = "Error shape={}".format(shape) | |||
| logger.error(msg) | |||
| raise ValueError(msg) | |||
| class seed_context: | |||
| '''set and restore seed''' | |||
| def __init__(self, init): | |||
| self.init = init | |||
| from .seed import get_seed | |||
| global_seed = get_seed() | |||
| self._np_seed = np.random.get_state()[1][0] | |||
| self.need_set_seed = ((slice_index is not None) and (global_seed is None)) | |||
| def __enter__(self): | |||
| if self.need_set_seed: | |||
| self.seed = self.init.seed | |||
| np.random.seed(slice_index) | |||
| self.init.seed = slice_index | |||
| def __exit__(self, ptype, value, trace): | |||
| if self.need_set_seed: | |||
| np.random.seed(self._np_seed) | |||
| self.init.seed, _ = self.seed | |||
| with seed_context(self.init): | |||
| self.init(arr) | |||
| data = np.array(arr) | |||
| if opt_shard_group: | |||
| rank = get_rank(opt_shard_group) | |||
| size = get_group_size(opt_shard_group) | |||
| data = np.split(data, size)[rank] | |||
| return Tensor(data, dtype=self.dtype) | |||
| def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None): | |||
| """Return init_data().""" | |||
| logger.warning("WARN_DEPRECATED: The usage of to_tensor is deprecated." | |||
| " Please use init_data") | |||
| return self.init_data(slice_index, shape, opt_shard_group) | |||
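`to_tensor` survives only as a deprecated alias that logs a warning and forwards to `init_data`, so existing call sites keep working while they migrate. The rewrite is mechanical (hedged sketch):

```python
import mindspore as ms
from mindspore.common.initializer import initializer

# Deprecated spelling: still works but logs WARN_DEPRECATED and forwards to init_data().
w_old = initializer('XavierUniform', shape=(8, 3, 3, 3), dtype=ms.float16).to_tensor()
# Preferred spelling after this change:
w_new = initializer('XavierUniform', shape=(8, 3, 3, 3), dtype=ms.float16).init_data()
assert w_old.shape == w_new.shape == (8, 3, 3, 3)
```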
| class RowTensor: | |||
| """ | |||
| A sparse representation of a set of tensor slices at given indices. | |||
| @@ -498,76 +580,6 @@ class SparseTensor: | |||
| return self.__dense_shape | |||
| class MetaTensor(MetaTensor_): | |||
| """ | |||
| The base class of the MetaTensor. | |||
| Initialization of tensor basic attributes and model weight values. | |||
| Returns: | |||
| Array, an array after being initialized. | |||
| """ | |||
| def __init__(self, dtype, shape, init=None): | |||
| # check param | |||
| self.init = init | |||
| MetaTensor_.__init__(self, dtype, shape) | |||
| def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None): | |||
| """ | |||
| Get the tensor format data of this MetaTensor. | |||
| Args: | |||
| slice_index (int): Slice index of a parameter's slices. | |||
| It is used when initialize a slice of a parameter, it guarantees that devices | |||
| using the same slice can generate the same tensor. | |||
| shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter. | |||
| opt_shard_group(str): Optimizer shard group which is used in auto or semi auto parallel mode | |||
| to get one shard of a parameter's slice. | |||
| """ | |||
| if self.init is None: | |||
| raise TypeError("to_dense must be set MetaTensor.init, init can't be None") | |||
| if shape is None: | |||
| shape = self.shape | |||
| try: | |||
| arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype)) | |||
| except ValueError: | |||
| msg = "Error shape={}".format(shape) | |||
| logger.error(msg) | |||
| raise ValueError(msg) | |||
| class seed_context: | |||
| '''set and restore seed''' | |||
| def __init__(self, init): | |||
| self.init = init | |||
| from .seed import get_seed | |||
| global_seed = get_seed() | |||
| self._np_seed = np.random.get_state()[1][0] | |||
| self.need_set_seed = ((slice_index is not None) and (global_seed is None)) | |||
| def __enter__(self): | |||
| if self.need_set_seed: | |||
| self.seed = self.init.seed | |||
| np.random.seed(slice_index) | |||
| self.init.seed = slice_index | |||
| def __exit__(self, ptype, value, trace): | |||
| if self.need_set_seed: | |||
| np.random.seed(self._np_seed) | |||
| self.init.seed, _ = self.seed | |||
| with seed_context(self.init): | |||
| self.init(arr) | |||
| data = np.array(arr) | |||
| if opt_shard_group: | |||
| rank = get_rank(opt_shard_group) | |||
| size = get_group_size(opt_shard_group) | |||
| data = np.split(data, size)[rank] | |||
| return Tensor(data, dtype=self.dtype) | |||
| def _vm_compare(*args): | |||
| """Implement `vm_compare` for tensor.""" | |||
| obj_str = args[-1] | |||
| @@ -29,7 +29,7 @@ from .._checkparam import Validator | |||
| from ..common import dtype as mstype | |||
| from ..common.api import _executor, _pynative_exec | |||
| from ..common.parameter import Parameter, ParameterTuple | |||
| from ..common.tensor import Tensor, MetaTensor | |||
| from ..common.tensor import Tensor | |||
| from ..ops.functional import cast | |||
| from ..ops.operations import HookBackward | |||
| from ..ops.primitive import Primitive | |||
| @@ -589,7 +589,7 @@ class Cell(Cell_): | |||
| new_inputs = [] | |||
| for i in inputs: | |||
| if isinstance(i, (Tensor, MetaTensor)): | |||
| if isinstance(i, Tensor): | |||
| new_inputs.append(i) | |||
| if self._auto_parallel_mode: | |||
| @@ -15,7 +15,7 @@ | |||
| """embedding""" | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import log as logger | |||
| from mindspore.common.tensor import Tensor, MetaTensor | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common.parameter import Parameter | |||
| @@ -101,8 +101,8 @@ class Embedding(Cell): | |||
| if padding_idx is not None: | |||
| self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH, | |||
| "padding_idx", self.cls_name) | |||
| if isinstance(self.init_tensor, MetaTensor): | |||
| self.init_tensor = self.init_tensor.to_tensor() | |||
| if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None: | |||
| self.init_tensor = self.init_tensor.init_data() | |||
| self.init_tensor = self.init_tensor.asnumpy() | |||
| self.init_tensor[self.padding_idx] = 0 | |||
| self.init_tensor = Tensor(self.init_tensor) | |||
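The `Embedding` layer shows the same guard inline: only a still-lazy `init_tensor` is materialized before the padding row is zeroed. A hedged restatement of that idiom outside the layer:

```python
import mindspore as ms
from mindspore import Tensor
from mindspore.common.initializer import initializer

padding_idx = 0
init_tensor = initializer('normal', [10, 4], ms.float32)      # may still be lazy here
if isinstance(init_tensor, Tensor) and init_tensor.init is not None:
    init_tensor = init_tensor.init_data()
weights = init_tensor.asnumpy()
weights[padding_idx] = 0                                       # zero out the padding row
init_tensor = Tensor(weights)
assert float(init_tensor.asnumpy()[padding_idx].sum()) == 0.0
```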
| @@ -23,7 +23,7 @@ from .. import signature as sig | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.tensor import Tensor, MetaTensor | |||
| from ...common.tensor import Tensor | |||
| from .._utils import get_broadcast_shape | |||
| from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | |||
| @@ -2542,10 +2542,10 @@ class Equal(_LogicBinaryOp): | |||
| def infer_value(self, x, y): | |||
| if x is None or y is None: | |||
| return None | |||
| if isinstance(x, MetaTensor): | |||
| x = x.to_tensor() | |||
| if isinstance(y, MetaTensor): | |||
| y = y.to_tensor() | |||
| if isinstance(x, Tensor) and x.has_init: | |||
| x = x.init_data() | |||
| if isinstance(y, Tensor) and y.has_init: | |||
| y = y.init_data() | |||
| return Tensor(x.asnumpy() == y.asnumpy()) | |||
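Operator code that used to special-case `MetaTensor` (here `Equal.infer_value`) now uses the same two-step check everywhere. The recurring pattern, expressed as a hypothetical helper that is not part of this diff:

```python
from mindspore import Tensor

def materialize(x):
    """Hypothetical helper capturing the recurring guard in this diff:
    trigger deferred initialization only if the value is still a lazy Tensor."""
    if isinstance(x, Tensor) and x.has_init:
        return x.init_data()
    return x
```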
| @@ -253,33 +253,33 @@ def default_recurisive_init(custom_cell): | |||
| if 'hm' in name or 'wh' in name or 'off' in name or 'kps' in name: | |||
| if isinstance(cell, (nn.Conv2d)): | |||
| cell.weight.set_data(init.initializer(RandomNormal(), cell.weight.data.shape, | |||
| cell.weight.data.dtype).to_tensor()) | |||
| cell.weight.data.dtype)) | |||
| if cell.bias is not None: | |||
| cell.bias.set_data(init.initializer('zeros', cell.bias.data.shape, | |||
| cell.bias.data.dtype).to_tensor()) | |||
| cell.bias.data.dtype)) | |||
| continue | |||
| if isinstance(cell, (nn.Conv2d)): | |||
| cell.weight.set_data(init.initializer(KaimingNormal(mode='fan_out'), | |||
| cell.weight.data.shape, | |||
| cell.weight.data.dtype).to_tensor()) | |||
| cell.weight.data.dtype)) | |||
| if cell.bias is not None: | |||
| cell.bias.set_data(init.initializer('zeros', cell.bias.data.shape, | |||
| cell.bias.data.dtype).to_tensor()) | |||
| cell.bias.data.dtype)) | |||
| elif isinstance(cell, nn.Dense): | |||
| cell.weight.set_data(init.initializer(KaimingNormal(mode='fan_out'), | |||
| cell.weight.data.shape, | |||
| cell.weight.data.dtype).to_tensor()) | |||
| cell.weight.data.dtype)) | |||
| if cell.bias is not None: | |||
| cell.bias.set_data(init.initializer('zeros', cell.bias.data.shape, | |||
| cell.bias.data.dtype).to_tensor()) | |||
| cell.bias.data.dtype)) | |||
| elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): | |||
| cell.gamma.set_data(init.initializer('ones', cell.gamma.data.shape).to_tensor()) | |||
| cell.beta.set_data(init.initializer('zeros', cell.beta.data.shape).to_tensor()) | |||
| cell.gamma.set_data(init.initializer('ones', cell.gamma.data.shape)) | |||
| cell.beta.set_data(init.initializer('zeros', cell.beta.data.shape)) | |||
| elif isinstance(cell, nn.Conv2dTranspose): | |||
| cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5), mode='fan_out'), | |||
| cell.weight.data.shape, | |||
| cell.weight.data.dtype).to_tensor()) | |||
| cell.weight.data.dtype)) | |||
| if cell.bias is not None: | |||
| cell.bias.set_data(init.initializer('zeros', cell.bias.data.shape, | |||
| cell.bias.data.dtype).to_tensor()) | |||
| cell.bias.data.dtype)) | |||
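In model code, `set_data` now accepts the lazy tensor returned by `initializer()` directly, so the trailing `.to_tensor()` calls are simply dropped; the same rewrite repeats across the Conv/Dense/BatchNorm branches above and in the model hunks below. A hedged before/after sketch:

```python
import mindspore.nn as nn
import mindspore.common.initializer as init

cell = nn.Conv2d(3, 16, 3)
# Old style (deprecated): materialize first, then pass a concrete Tensor.
#   cell.weight.set_data(init.initializer('XavierUniform', cell.weight.data.shape,
#                                         cell.weight.data.dtype).to_tensor())
# New style: hand the lazy tensor to set_data() and let the Parameter materialize it.
cell.weight.set_data(init.initializer('XavierUniform', cell.weight.data.shape,
                                      cell.weight.data.dtype))
```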
| @@ -31,7 +31,7 @@ def bias_init_zeros(shape): | |||
| def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | |||
| """Conv2D wrapper.""" | |||
| shape = (out_channels, in_channels, kernel_size, kernel_size) | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16) | |||
| shape_bias = (out_channels,) | |||
| biass = bias_init_zeros(shape_bias) | |||
| return nn.Conv2d(in_channels, out_channels, | |||
| @@ -29,7 +29,7 @@ class DenseNoTranpose(nn.Cell): | |||
| super(DenseNoTranpose, self).__init__() | |||
| self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16)) | |||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor()) | |||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16)) | |||
| self.matmul = P.MatMul(transpose_b=False) | |||
| self.bias_add = P.BiasAdd() | |||
| @@ -79,16 +79,16 @@ class Rcnn(nn.Cell): | |||
| self.test_batch_size = cfg.test_batch_size | |||
| shape_0 = (self.rcnn_fc_out_channels, representation_size) | |||
| weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor() | |||
| weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16) | |||
| shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels) | |||
| weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor() | |||
| weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16) | |||
| self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0) | |||
| self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1) | |||
| cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1], | |||
| dtype=mstype.float16).to_tensor() | |||
| dtype=mstype.float16) | |||
| reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1], | |||
| dtype=mstype.float16).to_tensor() | |||
| dtype=mstype.float16) | |||
| self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight) | |||
| self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight) | |||
| @@ -164,18 +164,18 @@ class RPN(nn.Cell): | |||
| shp_weight_conv = (feat_channels, in_channels, 3, 3) | |||
| shp_bias_conv = (feat_channels,) | |||
| weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor() | |||
| bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor() | |||
| weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16) | |||
| bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16) | |||
| shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) | |||
| shp_bias_cls = (num_anchors * cls_out_channels,) | |||
| weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor() | |||
| bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor() | |||
| weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16) | |||
| bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16) | |||
| shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) | |||
| shp_bias_reg = (num_anchors * 4,) | |||
| weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor() | |||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor() | |||
| weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16) | |||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16) | |||
| for i in range(num_layers): | |||
| rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | |||
| @@ -29,7 +29,7 @@ def bias_init_zeros(shape): | |||
| def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | |||
| """Conv2D wrapper.""" | |||
| shape = (out_channels, in_channels, kernel_size, kernel_size) | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16) | |||
| shape_bias = (out_channels,) | |||
| biass = bias_init_zeros(shape_bias) | |||
| return nn.Conv2d(in_channels, out_channels, | |||
| @@ -27,7 +27,7 @@ class DenseNoTranpose(nn.Cell): | |||
| def __init__(self, input_channels, output_channels, weight_init): | |||
| super(DenseNoTranpose, self).__init__() | |||
| self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16)) | |||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor()) | |||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16)) | |||
| self.matmul = P.MatMul(transpose_b=False) | |||
| self.bias_add = P.BiasAdd() | |||
| @@ -41,16 +41,16 @@ class FpnCls(nn.Cell): | |||
| super(FpnCls, self).__init__() | |||
| representation_size = input_channels * pool_size * pool_size | |||
| shape_0 = (output_channels, representation_size) | |||
| weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor() | |||
| weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16) | |||
| shape_1 = (output_channels, output_channels) | |||
| weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor() | |||
| weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16) | |||
| self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0) | |||
| self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1) | |||
| cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1], | |||
| dtype=mstype.float16).to_tensor() | |||
| dtype=mstype.float16) | |||
| reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1], | |||
| dtype=mstype.float16).to_tensor() | |||
| dtype=mstype.float16) | |||
| self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight) | |||
| self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight) | |||
| @@ -24,7 +24,7 @@ from mindspore.common.initializer import initializer | |||
| def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'): | |||
| """Conv2D wrapper.""" | |||
| shape = (out_channels, in_channels, kernel_size, kernel_size) | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16) | |||
| shape_bias = (out_channels,) | |||
| bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16)) | |||
| return nn.Conv2d(in_channels, out_channels, | |||
| @@ -34,7 +34,7 @@ def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mod | |||
| def _convTanspose(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'): | |||
| """ConvTranspose wrapper.""" | |||
| shape = (out_channels, in_channels, kernel_size, kernel_size) | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16) | |||
| shape_bias = (out_channels,) | |||
| bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16)) | |||
| return nn.Conv2dTranspose(in_channels, out_channels, | |||
| @@ -164,18 +164,18 @@ class RPN(nn.Cell): | |||
| shp_weight_conv = (feat_channels, in_channels, 3, 3) | |||
| shp_bias_conv = (feat_channels,) | |||
| weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor() | |||
| bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor() | |||
| weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16) | |||
| bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16) | |||
| shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) | |||
| shp_bias_cls = (num_anchors * cls_out_channels,) | |||
| weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor() | |||
| bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor() | |||
| weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16) | |||
| bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16) | |||
| shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) | |||
| shp_bias_reg = (num_anchors * 4,) | |||
| weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor() | |||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor() | |||
| weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16) | |||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16) | |||
| for i in range(num_layers): | |||
| rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | |||
| @@ -28,7 +28,7 @@ class DenseNoTranpose(nn.Cell): | |||
| super(DenseNoTranpose, self).__init__() | |||
| self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16), | |||
| name="weight") | |||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor(), name="bias") | |||
| self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16), name="bias") | |||
| self.matmul = P.MatMul(transpose_b=False) | |||
| self.bias_add = P.BiasAdd() | |||
| @@ -42,16 +42,16 @@ class FpnCls(nn.Cell): | |||
| super(FpnCls, self).__init__() | |||
| representation_size = input_channels * pool_size * pool_size | |||
| shape_0 = (output_channels, representation_size) | |||
| weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor() | |||
| weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16) | |||
| shape_1 = (output_channels, output_channels) | |||
| weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor() | |||
| weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16) | |||
| self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0) | |||
| self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1) | |||
| cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1], | |||
| dtype=mstype.float16).to_tensor() | |||
| dtype=mstype.float16) | |||
| reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1], | |||
| dtype=mstype.float16).to_tensor() | |||
| dtype=mstype.float16) | |||
| self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight) | |||
| self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight) | |||
| @@ -24,7 +24,7 @@ from mindspore.common.initializer import initializer | |||
| def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'): | |||
| """Conv2D wrapper.""" | |||
| shape = (out_channels, in_channels, kernel_size, kernel_size) | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16) | |||
| shape_bias = (out_channels,) | |||
| bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16)) | |||
| return nn.Conv2d(in_channels, out_channels, | |||
| @@ -34,7 +34,7 @@ def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mod | |||
| def _convTanspose(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'): | |||
| """ConvTranspose wrapper.""" | |||
| shape = (out_channels, in_channels, kernel_size, kernel_size) | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() | |||
| weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16) | |||
| shape_bias = (out_channels,) | |||
| bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16)) | |||
| return nn.Conv2dTranspose(in_channels, out_channels, | |||
| @@ -164,18 +164,18 @@ class RPN(nn.Cell): | |||
| shp_weight_conv = (feat_channels, in_channels, 3, 3) | |||
| shp_bias_conv = (feat_channels,) | |||
| weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor() | |||
| bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor() | |||
| weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16) | |||
| bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16) | |||
| shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) | |||
| shp_bias_cls = (num_anchors * cls_out_channels,) | |||
| weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor() | |||
| bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor() | |||
| weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16) | |||
| bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16) | |||
| shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) | |||
| shp_bias_reg = (num_anchors * 4,) | |||
| weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor() | |||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor() | |||
| weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16) | |||
| bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16) | |||
| for i in range(num_layers): | |||
| rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | |||
| @@ -34,7 +34,7 @@ def _make_layer(base, args, batch_norm): | |||
| weight = 'ones' | |||
| if args.initialize_mode == "XavierUniform": | |||
| weight_shape = (v, in_channels, 3, 3) | |||
| weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() | |||
| weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32) | |||
| conv2d = nn.Conv2d(in_channels=in_channels, | |||
| out_channels=v, | |||
| @@ -158,7 +158,7 @@ def default_recurisive_init(custom_cell): | |||
| if isinstance(cell, nn.Conv2d): | |||
| cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), | |||
| cell.weight.data.shape, | |||
| cell.weight.data.dtype).to_tensor()) | |||
| cell.weight.data.dtype)) | |||
| if cell.bias is not None: | |||
| fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.data.asnumpy()) | |||
| bound = 1 / math.sqrt(fan_in) | |||
| @@ -167,7 +167,7 @@ def default_recurisive_init(custom_cell): | |||
| elif isinstance(cell, nn.Dense): | |||
| cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), | |||
| cell.weight.data.shape, | |||
| cell.weight.data.dtype).to_tensor()) | |||
| cell.weight.data.dtype)) | |||
| if cell.bias is not None: | |||
| fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.data.asnumpy()) | |||
| bound = 1 / math.sqrt(fan_in) | |||
| @@ -879,13 +879,13 @@ class CreateAttentionMaskFromInputMask(nn.Cell): | |||
| if not self.input_mask_from_dataset: | |||
| self.input_mask = initializer( | |||
| "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() | |||
| "ones", [config.batch_size, config.seq_length], mstype.int32).init_data() | |||
| self.cast = P.Cast() | |||
| self.reshape = P.Reshape() | |||
| self.shape = (config.batch_size, 1, config.seq_length) | |||
| self.broadcast_ones = initializer( | |||
| "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() | |||
| "ones", [config.batch_size, config.seq_length, 1], mstype.float32).init_data() | |||
| self.batch_matmul = P.BatchMatMul() | |||
| def construct(self, input_mask): | |||
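Where a network needs a concrete constant right away (the BERT attention-mask buffers above), `.to_tensor()` is renamed to `.init_data()` rather than dropped, because the value is consumed as real data in `construct`. A hedged sketch of that idiom:

```python
import mindspore as ms
from mindspore.common.initializer import initializer

batch_size, seq_length = 2, 8
# Materialize immediately: this buffer is consumed as real data in construct(),
# so it cannot stay lazy the way a Parameter's default value can.
broadcast_ones = initializer('ones', [batch_size, seq_length, 1], ms.float32).init_data()
assert float(broadcast_ones.asnumpy().sum()) == batch_size * seq_length
```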
| @@ -932,7 +932,7 @@ class BertModel(nn.Cell): | |||
| if not self.token_type_ids_from_dataset: | |||
| self.token_type_ids = initializer( | |||
| "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() | |||
| "zeros", [self.batch_size, self.seq_length], mstype.int32).init_data() | |||
| self.bert_embedding_lookup = Embedding_Thor( | |||
| vocab_size=config.vocab_size, | |||
| @@ -209,7 +209,7 @@ def test_bert_performance(): | |||
| for param in params: | |||
| value = param.data | |||
| name = param.name | |||
| if isinstance(value, Tensor): | |||
| if isinstance(value, Tensor) and not value.has_init: | |||
| if name.split('.')[-1] in ['weight']: | |||
| if name.split('.')[-3] in ['cls2']: | |||
| logger.info("***************** BERT param name is 1 {}".format(name)) | |||
| @@ -817,13 +817,13 @@ class CreateAttentionMaskFromInputMask(nn.Cell): | |||
| if not self.input_mask_from_dataset: | |||
| self.input_mask = initializer( | |||
| "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() | |||
| "ones", [config.batch_size, config.seq_length], mstype.int32).init_data() | |||
| self.cast = P.Cast() | |||
| self.reshape = P.Reshape() | |||
| self.shape = (config.batch_size, 1, config.seq_length) | |||
| self.broadcast_ones = initializer( | |||
| "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() | |||
| "ones", [config.batch_size, config.seq_length, 1], mstype.float32).init_data() | |||
| self.batch_matmul = P.BatchMatMul() | |||
| def construct(self, input_mask): | |||
| @@ -869,7 +869,7 @@ class BertModel(nn.Cell): | |||
| if not self.token_type_ids_from_dataset: | |||
| self.token_type_ids = initializer( | |||
| "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() | |||
| "zeros", [self.batch_size, self.seq_length], mstype.int32).init_data() | |||
| self.bert_embedding_lookup = EmbeddingLookup( | |||
| vocab_size=config.vocab_size, | |||
| @@ -195,14 +195,14 @@ def test_parameter_lazy_init(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8) | |||
| # Call init_data() without set set_data. | |||
| para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1') | |||
| assert not isinstance(para.data, Tensor) | |||
| assert isinstance(para.data, Tensor) | |||
| para = para.init_data() | |||
| assert isinstance(para.data, Tensor) | |||
| assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3))) | |||
| # Call init_data() after set_data is set. | |||
| para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2') | |||
| assert not isinstance(para.data, Tensor) | |||
| assert isinstance(para.data, Tensor) | |||
| # expect type error when not init | |||
| with pytest.raises(TypeError): | |||
| para.set_data(Tensor(np.zeros((1, 2, 3)))) | |||
| @@ -20,7 +20,7 @@ import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore._c_expression import MetaTensor as MetaTensor_ | |||
| from mindspore._c_expression import MetaTensor | |||
| from mindspore.common import dtype | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops import functional as F | |||
| @@ -70,8 +70,8 @@ def scalar_mul_while(x): | |||
| return rv | |||
| @ms_function(input_signature=(MetaTensor_(dtype.float32, (1, 1, 3, 3)), | |||
| MetaTensor_(dtype.float32, (1, 1, 3, 3)))) | |||
| @ms_function(input_signature=(MetaTensor(dtype.float32, (1, 1, 3, 3)), | |||
| MetaTensor(dtype.float32, (1, 1, 3, 3)))) | |||
| def tensor_add_test(x, y): | |||
| """ tensor_add_test """ | |||
| z = F.tensor_add(x, y) | |||
| @@ -24,7 +24,7 @@ import mindspore.common.initializer as init | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.tensor import Tensor, MetaTensor | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.nn import Conv2d | |||
| from mindspore.ops import operations as P | |||
| from ..ut_filter import non_graph_engine | |||
| @@ -58,8 +58,8 @@ def _check_uniform(tensor, boundary_a, boundary_b): | |||
| def test_init_Initializer(): | |||
| tensor = init.initializer(InitTwo(), [2, 2], ms.int32) | |||
| assert tensor.shape == [2, 2] | |||
| _check_value(tensor.to_tensor(), 2, 2) | |||
| assert tensor.shape == (2, 2) | |||
| _check_value(tensor.init_data(), 2, 2) | |||
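A subtle consequence checked here: the old `MetaTensor` reported its shape as a list, while `Tensor.shape` is a tuple, so the shape assertion changes from `[2, 2]` to `(2, 2)`. Hedged sketch:

```python
import mindspore as ms
import mindspore.common.initializer as init

t = init.initializer(init.Zero(), [2, 2], ms.int32)
assert t.shape == (2, 2)               # Tensor.shape is a tuple, even before init_data()
assert t.init_data().shape == (2, 2)
```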
| def test_init_tensor(): | |||
| @@ -71,70 +71,70 @@ def test_init_tensor(): | |||
| def test_init_zero_default_dtype(): | |||
| tensor = init.initializer(init.Zero(), [2, 2]) | |||
| assert tensor.dtype == ms.float32 | |||
| _check_value(tensor.to_tensor(), 0, 0) | |||
| _check_value(tensor.init_data(), 0, 0) | |||
| def test_init_zero(): | |||
| tensor = init.initializer(init.Zero(), [2, 2], ms.float32) | |||
| _check_value(tensor.to_tensor(), 0, 0) | |||
| _check_value(tensor.init_data(), 0, 0) | |||
| def test_init_zero_alias_default_dtype(): | |||
| tensor = init.initializer('zeros', [1, 2]) | |||
| assert tensor.dtype == ms.float32 | |||
| _check_value(tensor.to_tensor(), 0, 0) | |||
| _check_value(tensor.init_data(), 0, 0) | |||
| def test_init_zero_alias(): | |||
| tensor = init.initializer('zeros', [1, 2], ms.float32) | |||
| _check_value(tensor.to_tensor(), 0, 0) | |||
| _check_value(tensor.init_data(), 0, 0) | |||
| def test_init_one(): | |||
| tensor = init.initializer(init.One(), [2, 2], ms.float32) | |||
| _check_value(tensor.to_tensor(), 1, 1) | |||
| _check_value(tensor.init_data(), 1, 1) | |||
| def test_init_one_alias(): | |||
| tensor = init.initializer('ones', [1, 2], ms.float32) | |||
| _check_value(tensor.to_tensor(), 1, 1) | |||
| _check_value(tensor.init_data(), 1, 1) | |||
| def test_init_constant(): | |||
| tensor = init.initializer(init.Constant(1), [2, 2], ms.float32) | |||
| _check_value(tensor.to_tensor(), 1, 1) | |||
| _check_value(tensor.init_data(), 1, 1) | |||
| def test_init_uniform(): | |||
| scale = 10 | |||
| tensor = init.initializer(init.Uniform(scale=scale), [5, 4], ms.float32) | |||
| _check_value(tensor.to_tensor(), -scale, scale) | |||
| _check_value(tensor.init_data(), -scale, scale) | |||
| def test_init_uniform_alias(): | |||
| scale = 100 | |||
| tensor = init.initializer('uniform', [5, 4], ms.float32) | |||
| _check_value(tensor.to_tensor(), -scale, scale) | |||
| _check_value(tensor.init_data(), -scale, scale) | |||
| def test_init_normal(): | |||
| tensor = init.initializer(init.Normal(), [5, 4], ms.float32) | |||
| assert isinstance(tensor, MetaTensor), 'Normal init failed!' | |||
| assert isinstance(tensor, Tensor), 'Normal init failed!' | |||
| def test_init_truncated_normal(): | |||
| tensor = init.initializer(init.TruncatedNormal(), [5, 4], ms.float32) | |||
| assert isinstance(tensor, MetaTensor), 'TruncatedNormal init failed!' | |||
| assert isinstance(tensor, Tensor), 'TruncatedNormal init failed!' | |||
| def test_init_normal_alias(): | |||
| tensor = init.initializer('normal', [5, 4], ms.float32) | |||
| assert isinstance(tensor, MetaTensor), 'Normal init failed!' | |||
| assert isinstance(tensor, Tensor), 'Normal init failed!' | |||
| def test_init_truncatednormal_alias(): | |||
| tensor = init.initializer('truncatednormal', [5, 4], ms.float32) | |||
| assert isinstance(tensor, MetaTensor), 'TruncatedNormal init failed!' | |||
| assert isinstance(tensor, Tensor), 'TruncatedNormal init failed!' | |||
| def test_init_abnormal(): | |||
| @@ -144,18 +144,18 @@ def test_init_abnormal(): | |||
| def test_initializer_reinit(): | |||
| weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16) | |||
| assert isinstance(weights, MetaTensor), 'XavierUniform init failed!' | |||
| assert isinstance(weights, Tensor), 'XavierUniform init failed!' | |||
| def test_init_xavier_uniform(): | |||
| """ test_init_xavier_uniform """ | |||
| gain = 1.2 | |||
| tensor1 = init.initializer(init.XavierUniform(gain=gain), [20, 22], ms.float32).to_tensor() | |||
| tensor2 = init.initializer(init.XavierUniform(), [20, 22], ms.float32).to_tensor() | |||
| tensor3 = init.initializer(init.XavierUniform(gain=gain), [20, 22, 5, 5], ms.float32).to_tensor() | |||
| tensor4 = init.initializer(init.XavierUniform(), [20, 22, 5, 5], ms.float32).to_tensor() | |||
| tensor5 = init.initializer('xavier_uniform', [20, 22, 5, 5], ms.float32).to_tensor() | |||
| tensor6 = init.initializer('xavier_uniform', [20, 22], ms.float32).to_tensor() | |||
| tensor1 = init.initializer(init.XavierUniform(gain=gain), [20, 22], ms.float32).init_data() | |||
| tensor2 = init.initializer(init.XavierUniform(), [20, 22], ms.float32).init_data() | |||
| tensor3 = init.initializer(init.XavierUniform(gain=gain), [20, 22, 5, 5], ms.float32).init_data() | |||
| tensor4 = init.initializer(init.XavierUniform(), [20, 22, 5, 5], ms.float32).init_data() | |||
| tensor5 = init.initializer('xavier_uniform', [20, 22, 5, 5], ms.float32).init_data() | |||
| tensor6 = init.initializer('xavier_uniform', [20, 22], ms.float32).init_data() | |||
| tensor_dict = {tensor1: gain, tensor2: None, tensor3: gain, tensor4: None, tensor5: None, tensor6: None} | |||
| for tensor, gain_value in tensor_dict.items(): | |||
| @@ -175,7 +175,7 @@ def test_init_xavier_uniform(): | |||
| def test_init_xavier_uniform_error(): | |||
| with py.raises(ValueError): | |||
| init.initializer(init.XavierUniform(), [6], ms.float32).to_tensor() | |||
| init.initializer(init.XavierUniform(), [6], ms.float32).init_data() | |||
| def test_init_he_uniform(): | |||
| @@ -184,7 +184,7 @@ def test_init_he_uniform(): | |||
| tensor2 = init.initializer(init.HeUniform(), [20, 22, 5, 5], ms.float32) | |||
| tensor3 = init.initializer('he_uniform', [20, 22, 5, 5], ms.float32) | |||
| tensor4 = init.initializer('he_uniform', [20, 22], ms.float32) | |||
| tensors = [tensor1.to_tensor(), tensor2.to_tensor(), tensor3.to_tensor(), tensor4.to_tensor()] | |||
| tensors = [tensor1.init_data(), tensor2.init_data(), tensor3.init_data(), tensor4.init_data()] | |||
| for tensor in tensors: | |||
| shape = tensor.asnumpy().shape | |||
| @@ -200,7 +200,7 @@ def test_init_he_uniform(): | |||
| def test_init_he_uniform_error(): | |||
| with py.raises(ValueError): | |||
| init.initializer(init.HeUniform(), [6], ms.float32).to_tensor() | |||
| init.initializer(init.HeUniform(), [6], ms.float32).init_data() | |||
| def test_conv2d_abnormal_kernel_negative(): | |||
| @@ -224,7 +224,7 @@ def test_conv2d_abnormal_kernel_normal(): | |||
| @non_graph_engine | |||
| def test_conv2d_abnormal_kernel_truncated_normal(): | |||
| input_data = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32).to_tensor() | |||
| input_data = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32).init_data() | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| model = ms.Model( | |||
| Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3, | |||