Merge pull request !1015 from candanzg/me_with_shapetags/v0.3.0-alpha
@@ -319,6 +319,10 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
std::shared_ptr<tensor::Tensor> m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << std::string(py::str(shape)) << "]";
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
std::shared_ptr<tensor::MetaTensor> m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << std::string(py::str(shape)) << "]";
}
}
buffer_ << "</td></tr>";
@@ -102,6 +102,26 @@ int MetaTensor::DimensionSize(const size_t index) const {
return dim_size;
}
abstract::AbstractBasePtr MetaTensor::ToAbstract() {
auto tens = shared_from_base<MetaTensor>();
auto dtype = tens->Dtype();
if (!IsSubType(dtype, kNumber)) {
MS_LOG(EXCEPTION) << "Expect MetaTensor type kNumber but got: " << dtype->ToString() << ".";
}
auto tensor_shape = tens->shape();
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
abs_tensor->set_value(shared_from_base<MetaTensor>());
return abs_tensor;
}
py::tuple MetaTensor::GetPyTupleShape() const {
py::tuple dims(shape_.size());
for (size_t i = 0; i < dims.size(); ++i) {
dims[i] = py::int_(shape_[i]);
}
return dims;
}
int MetaTensor::ElementsNum() const {
return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int>());
}
@@ -197,14 +217,6 @@ int Tensor::DataDim() const { return static_cast<int>(data_.ndim()); }
int Tensor::DataSize() const { return static_cast<int>(data_.size()); }
py::tuple Tensor::GetPyTupleShape() const {
py::tuple dims(shape_.size());
for (size_t i = 0; i < dims.size(); ++i) {
dims[i] = py::int_(shape_[i]);
}
return dims;
}
py::array Tensor::data() const { return data_; }
int Tensor::data_type_c() const { return static_cast<int>(data_type_); }
@@ -547,7 +559,10 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
return tensor;
}));
(void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"));
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"))
.def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
.def("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def("shape", &MetaTensor::GetPyTupleShape, "Get the MetaTensor's shape.");
}));
} // namespace tensor
@@ -163,6 +163,8 @@ class MetaTensor : public Value {
//
// All the types are defined in "ir/dtype.h".
TypePtr Dtype() const;
abstract::AbstractBasePtr ToAbstract() override;
py::tuple GetPyTupleShape() const;
TypeId data_type() const { return data_type_; }
std::string ToString() const override;
std::string DumpText() const override;
@@ -230,6 +232,7 @@ class MetaTensor : public Value {
return false;
}
}
const bool parse_info_ = true;
protected:
// brief Data type of the tensor.
@@ -348,11 +351,6 @@ class Tensor : public MetaTensor {
// return The total number of elements of the tensor data.
int DataSize() const;
// brief Get tensor's shape
//
// return [py::tuple] The tensor's shape
py::tuple GetPyTupleShape() const;
// brief Tensor's data value.
//
// return [py::array] The tensor's data in py::array.
@@ -423,6 +421,7 @@ class Tensor : public MetaTensor {
};
using TensorPtr = std::shared_ptr<Tensor>;
using MetaTensorPtr = std::shared_ptr<MetaTensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
} // namespace tensor
@@ -36,6 +36,8 @@ namespace mindspore {
namespace parse {
using Tensor = mindspore::tensor::Tensor;
using TensorPtr = mindspore::tensor::TensorPtr;
using MetaTensor = mindspore::tensor::MetaTensor;
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
namespace {
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
@@ -181,6 +183,18 @@ bool ConvertDataType(const py::object &obj, ValuePtr *const data) {
return true;
}
bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting MetaTensor object.";
auto m_tensor = obj.cast<MetaTensorPtr>();
if (m_tensor == nullptr) {
MS_LOG(ERROR) << "Resolve MetaTensor error, get ptr is null.";
return false;
}
*data = m_tensor;
return true;
}
bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting tensor object";
@@ -283,6 +297,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
ret = ConvertDataType(obj, &converted);
} else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) {
ret = ConvertTensor(obj, &converted);
} else if (py::hasattr(obj, PYTHON_META_TENSOR_FLAG)) {
ret = ConvertMetaTensor(obj, &converted);
} else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env;
@@ -20,6 +20,7 @@ namespace mindspore {
const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__";
const char PYTHON_METAFUNCGRAPH_FLAG[] = "__metafuncgraph_flag__";
const char PYTHON_TENSOR_FLAG[] = "__tensor_flag__";
const char PYTHON_META_TENSOR_FLAG[] = "__meta_tensor_flag__";
const char PYTHON_ENVINSTANCE_FLAG[] = "__envinstance_flag__";
const char PYTHON_DTYPE_FLAG[] = "__dtype_flag__";
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
@@ -22,6 +22,7 @@ namespace mindspore {
extern const char PYTHON_PRIMITIVE_FLAG[];
extern const char PYTHON_METAFUNCGRAPH_FLAG[];
extern const char PYTHON_TENSOR_FLAG[];
extern const char PYTHON_META_TENSOR_FLAG[];
extern const char PYTHON_ENVINSTANCE_FLAG[];
extern const char PYTHON_DTYPE_FLAG[];
extern const char PYTHON_CELL_AS_LIST[];
@@ -71,6 +71,11 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
py::tuple v(1);
v[0] = value->cast<tensor::TensorPtr>();
ret = v[0];
} else if (value->isa<tensor::MetaTensor>()) {
MS_LOG(DEBUG) << "MetaTensor";
py::tuple v(1);
v[0] = value->cast<tensor::MetaTensorPtr>();
ret = v[0];
} else if (value->isa<RefKey>()) {
MS_LOG(DEBUG) << "RefKey";
py::tuple v(1);
@@ -326,6 +326,12 @@ class _Executor:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))
def _params_init_data(self, obj, params):
if params is not None:
for _, param in params.items():
param.init_data()
obj.init_parameters_data()
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
"""
Compiles graph.
@@ -371,6 +377,7 @@ class _Executor:
if not do_convert:
return phase, True
self._params_init_data(obj, params)
if not enable_debug_runtime or enable_ge:
if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
@@ -39,6 +39,8 @@ class Initializer:
"""
def __init__(self, **kwargs):
self._kwargs = kwargs
self.shape = None
self.dtype = None
def _initialize(self, *kwargs):
raise NotImplementedError('Must be overridden!')
@@ -46,6 +48,32 @@ class Initializer:
def __call__(self, arr):
return self._initialize(arr)
@property
def shape(self):
return self._shape
@shape.setter
def shape(self, shape):
self._shape = shape
@property
def dtype(self):
return self._dtype
@dtype.setter
def dtype(self, dtype):
self._dtype = dtype
def to_tensor(self):
arr = None
try:
arr = np.ndarray(self.shape)
except ValueError:
msg = "Error shape={}".format(self.shape)
logger.error(msg)
raise ValueError(msg)
self.__call__(arr)
return Tensor(arr, dtype=self.dtype)
def _register(*aliases):
"""Return the alias register."""
@@ -279,13 +307,14 @@ def initializer(init, shape=None, dtype=mstype.float32):
dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
Returns:
Tensor, initialized tensor.
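The registration above gives Python a MetaTensor class with a (dtype, shape) constructor plus dtype and shape accessors. A minimal sketch of how the binding could be exercised, assuming MetaTensor is re-exported through mindspore.common.tensor (the parameter.py import later in this PR suggests it is):

    import mindspore as ms
    from mindspore.common.tensor import MetaTensor  # assumed re-export path

    # A shape/dtype-only placeholder; no data buffer is attached yet.
    meta = MetaTensor(ms.float32, (2, 3))
    print(meta.dtype())   # bound to MetaTensor::Dtype
    print(meta.shape())   # bound to MetaTensor::GetPyTupleShape -> (2, 3)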
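The new shape/dtype properties and to_tensor() let an Initializer be carried around as a lightweight recipe and materialized only when data is actually needed. A hedged sketch of that flow using the existing One initializer:

    import mindspore as ms
    import mindspore.common.initializer as init

    one = init.One()
    one.shape = (2, 2)       # stored by the new shape setter
    one.dtype = ms.float32   # stored by the new dtype setter
    t = one.to_tensor()      # allocates the ndarray, runs _initialize, returns a Tensor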
Union[Tensor, Initializer], when `init` is a Tensor, the return is a Tensor object,
otherwise the return is an Initializer object.
Examples:
>>> tensor = initializer('ones', [1, 2, 3], mindspore.float32)
"""
if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
raise TypeError('Unsupported init type.')
raise TypeError("Unsupported init type '{}'.".format(type(init)))
if isinstance(init, Tensor):
init_shape = init.shape()
@@ -295,23 +324,32 @@ def initializer(init, shape=None, dtype=mstype.float32):
"the variable shape {}.".format(list(init.shape()), shape))
return init
if isinstance(init, str):
init_obj = _INITIALIZER_ALIAS[init.lower()]()
if init_obj is None:
raise ValueError("The class corresponding to '{}' was not found.".format(init))
init = init_obj
if isinstance(shape, list):
shape = tuple(shape)
elif isinstance(shape, numbers.Number):
shape = (shape,)
try:
arr = np.ndarray(shape)
np.ndarray(shape)
except ValueError:
msg = "Error shape={}".format(shape)
logger.error(msg)
raise ValueError(msg)
raise ValueError("Error shape={}".format(shape))
if isinstance(init, Initializer):
init.shape = shape
init.dtype = dtype
return init
if isinstance(init, numbers.Number):
init_obj = Constant(init)
elif isinstance(init, str):
init_obj = _INITIALIZER_ALIAS[init.lower()]()
else:
init_obj = init
init_obj(arr)
return Tensor(arr, dtype=dtype)
init_obj.shape = shape
init_obj.dtype = dtype
return init_obj
raise TypeError("Unsupported init type '{}'.".format(type(init)))
__all__ = [
'Initializer',
@@ -14,9 +14,10 @@
# ============================================================================
"""Parameter for cell."""
import numbers
from copy import copy, deepcopy
from .initializer import initializer
from .tensor import Tensor
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
from ..parallel._utils import _set_clone_info, _CloneInfo
@@ -41,7 +42,8 @@ class Parameter:
Each parameter of Cell is represented by Parameter class.
Args:
default_input (Tensor): A parameter tensor.
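With this change initializer() returns a Tensor only when it is handed a Tensor; for a str, number, or Initializer it records shape and dtype on an Initializer object and returns that object, deferring allocation to to_tensor(). A short sketch of the updated call pattern, mirroring the test updates later in this PR:

    import mindspore as ms
    import mindspore.common.initializer as init

    obj = init.initializer('ones', [2, 3], ms.float32)  # now an Initializer, not a Tensor
    print(obj.shape, obj.dtype)                          # (2, 3) and the requested dtype
    tensor = obj.to_tensor()                             # materialize when data is needed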
default_input (Union[Tensor, Initializer]): Parameter data. When `default_input` is `Initializer`,
the data stored by Parameter is `MetaTensor`, otherwise it is `Tensor`.
name (str): Name of the child parameter.
requires_grad (bool): True if the parameter requires gradient. Default: True.
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
@@ -123,7 +125,11 @@ class Parameter:
if init != 'same':
shape = self.default_input.shape()
dtype = self.default_input.dtype()
x.default_input = initializer(init, shape=shape, dtype=dtype)
if isinstance(init, (str, Initializer, numbers.Number)):
x.init_mode = initializer(init, shape=shape, dtype=dtype)
x.default_input = MetaTensor(dtype, shape)
else:
x.default_input = initializer(init, shape=shape, dtype=dtype)
x.clone_info = copy(self.clone_info)
_set_clone_info(self.clone_info, x.clone_info)
@@ -181,11 +187,21 @@ class Parameter:
if isinstance(data, Tensor):
# make a copy of Tensor to init the parameter
data = Tensor(data.asnumpy().copy())
elif isinstance(data, Initializer):
self.init_mode = data
data = MetaTensor(self.init_mode.dtype, self.init_mode.shape)
else:
data = Tensor(data)
self.default_input = data
def init_data(self):
if not isinstance(self.default_input, MetaTensor):
return
self.default_input = self.init_mode.to_tensor()
self.init_mode = None
class ParameterTuple(tuple):
"""
Class for storing tuple of parameters.
@@ -92,7 +92,7 @@ class GetMaskedLMOutput(nn.Cell):
config.hidden_size,
weight_init=weight_init,
activation=config.hidden_act).to_float(config.compute_type)
self.layernorm = nn.LayerNorm(config.hidden_size).to_float(config.compute_type)
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
self.output_bias = Parameter(
initializer(
'zero',
@@ -190,7 +190,7 @@ class EmbeddingPostprocessor(nn.Cell):
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
self.layernorm = nn.LayerNorm(embedding_size)
self.layernorm = nn.LayerNorm((embedding_size,))
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.GatherV2()
self.use_relative_positions = use_relative_positions
@@ -246,7 +246,7 @@ class BertOutput(nn.Cell):
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.add = P.TensorAdd()
self.layernorm = nn.LayerNorm(out_channels).to_float(compute_type)
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):
@@ -802,13 +802,13 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
if not self.input_mask_from_dataset:
self.input_mask = initializer(
"ones", [config.batch_size, config.seq_length], mstype.int32)
"ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = (config.batch_size, 1, config.seq_length)
self.broadcast_ones = initializer(
"ones", [config.batch_size, config.seq_length, 1], mstype.float32)
"ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
self.batch_matmul = P.BatchMatMul()
def construct(self, input_mask):
@@ -854,7 +854,7 @@ class BertModel(nn.Cell):
if not self.token_type_ids_from_dataset:
self.token_type_ids = initializer(
"zeros", [self.batch_size, self.seq_length], mstype.int32)
"zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
self.bert_embedding_lookup = EmbeddingLookup(
vocab_size=config.vocab_size,
@@ -29,7 +29,7 @@ from .mobilenet import InvertedResidual, ConvBNReLU
def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
padding=0, pad_mode=pad_mod, weight_init=weight)
@@ -26,7 +26,7 @@ def _make_layer(base, batch_norm):
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
weight_shape = (v, in_channels, 3, 3)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
conv2d = nn.Conv2d(in_channels=in_channels,
out_channels=v,
kernel_size=3,
@@ -163,6 +163,7 @@ class Cell:
if context.get_context("mode") == context.GRAPH_MODE:
out = self.compile_and_run(*inputs)
return out
self.init_parameters_data()
output = self.construct(*inputs)
if isinstance(output, Parameter):
output = output.data
@@ -395,6 +396,10 @@ class Cell:
"""
raise NotImplementedError
def init_parameters_data(self, recurse=True):
for param in self.get_parameters(expand=recurse):
param.init_data()
def parameters_dict(self, recurse=True):
"""
Gets parameters dictionary.
@@ -471,6 +471,9 @@ class LayerNorm(Cell):
beta_init='zeros',
):
super(LayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
.format(normalized_shape, type(normalized_shape)))
self.normalized_shape = normalized_shape
self.begin_norm_axis = begin_norm_axis
self.begin_params_axis = begin_params_axis
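Combined with the docstring change above, a Parameter built from an Initializer now stores a MetaTensor placeholder and keeps the Initializer in init_mode until init_data() swaps in a real Tensor. A hedged usage sketch:

    import mindspore as ms
    import mindspore.common.initializer as init
    from mindspore.common.parameter import Parameter

    w = Parameter(init.initializer('normal', [3, 4], ms.float32), name='w')
    # w.default_input is a MetaTensor here; only dtype and shape are known.
    w.init_data()
    # w.default_input is now a concrete Tensor produced by Initializer.to_tensor().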
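init_parameters_data() just walks the cell's parameters and forces any lazily initialized one to materialize; the pynative __call__ path above and several tests below rely on it. A minimal sketch, where SomeCell stands in for any Cell whose parameters were built from Initializer objects:

    net = SomeCell()              # hypothetical Cell with Initializer-backed parameters
    net.init_parameters_data()    # every Parameter.default_input becomes a real Tensor
    for p in net.trainable_params():
        print(p.name, p.default_input.asnumpy().shape)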
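Because of the check above, LayerNorm now rejects a bare int for normalized_shape, which is why the BERT call sites earlier in this PR switch to one-element tuples. For example:

    import mindspore.nn as nn

    layernorm = nn.LayerNorm((768,))   # tuple/list required; nn.LayerNorm(768) now raises TypeError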
@@ -116,6 +116,8 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
param_value = checkpoint_list.value.add()
param_value.tag = param["name"]
param_tensor = param_value.tensor
if isinstance(param["data"], Parameter):
param["data"].init_data()
param_data = param["data"].asnumpy().reshape(-1)
param_tensor.tensor_content = param_data.tostring()
param_tensor.tensor_type = str(param["data"].dtype())
@@ -238,6 +240,7 @@ def load_param_into_net(net, parameter_dict):
logger.error("Failed to combine the net and the parameters.")
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
raise TypeError(msg)
param.init_data()
_update_param(param, new_param)
else:
param_not_load.append(param.name)
@@ -311,6 +314,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
param_list = []
for (key, value) in param_dict.items():
each_param = {"name": key}
value.init_data()
if isinstance(value.data, Tensor):
param_data = value.data
else:
@@ -371,6 +375,8 @@ def _fill_param_into_net(net, parameter_list):
parameter_dict = {}
for each_param in parameter_list:
param_name = each_param["name"]
if isinstance(each_param["data"], Parameter):
each_param["data"].init_data()
np_val = each_param["data"].asnumpy()
if np_val.shape == (1,):
parameter_dict[param_name] = Parameter(np_val, name=param_name)
@@ -35,6 +35,7 @@ def get_uniform_with_shape(shape):
def set_block_param_with_rand(net, rand_func=None):
if not isinstance(net, nn.Cell) or rand_func is None:
return
net.init_parameters_data()
for param in net.trainable_params():
param.default_input = Tensor(rand_func(param.default_input.asnumpy().shape))
@@ -143,6 +143,7 @@ def test_bert_tdt():
callback = ModelCallback()
params = netwithloss.trainable_params()
for param in params:
param.init_data()
value = param.default_input
name = param.name
if isinstance(value, Tensor):
@@ -223,6 +223,7 @@ def test_div():
@non_graph_engine
def test_parameter():
x = Parameter(initializer(1, [1], ms.float32), name="beta1_power")
x.init_data()
z = x / 2
print(z)
@@ -34,7 +34,7 @@ def test_dense_str_activation():
assert isinstance(dense.activation, nn.ReLU)
input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
dense.construct(input_data)
dense(input_data)
def test_dense_weight_error():
@@ -40,8 +40,10 @@ class ParameterNet(nn.Cell):
def test_using_same_seed_for_initializer():
np.random.seed(0)
net1 = ParameterNet()
net1.init_parameters_data()
np.random.seed(0)
net2 = ParameterNet()
net2.init_parameters_data()
for key in net1.parameters_dict():
if key not in net2.parameters_dict():
assert False
@@ -52,8 +54,10 @@ def test_using_same_seed_for_initializer():
def test_using_diffserent_seed_for_initializer():
np.random.seed(0)
net1 = ParameterNet()
net1.init_parameters_data()
np.random.seed(1)
net2 = ParameterNet()
net2.init_parameters_data()
for key in net1.parameters_dict():
if key not in net2.parameters_dict():
assert False
@@ -59,7 +59,7 @@ def test_bn2d():
#3-channel RGB
input_data = Tensor(np.random.randint(0, 1, [1, 3, 224, 224]).astype(np.float32))
output = bn.construct(input_data)
output = bn(input_data)
output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
@@ -68,7 +68,7 @@ def test_bn1d():
"""ut of nn.BatchNorm1d"""
bn = nn.BatchNorm1d(3)
input_data = Tensor(np.random.randint(0, 1, [1, 3, 100, 100]).astype(np.float32))
output = bn.construct(input_data)
output = bn(input_data)
output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
@@ -27,7 +27,7 @@ kernel_size = 3
def test_check_conv2d_1():
m = nn.Conv2d(3, 64, 3, bias_init='zeros')
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
@@ -35,7 +35,7 @@ def test_check_conv2d_1():
def test_check_conv2d_2():
Tensor(np.ones([2, 2]))
m = nn.Conv2d(3, 64, 4, has_bias=False, weight_init='normal')
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
@@ -43,7 +43,7 @@ def test_check_conv2d_2():
def test_check_conv2d_3():
Tensor(np.ones([2, 2]))
m = nn.Conv2d(3, 64, (3, 3))
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
@@ -51,13 +51,13 @@ def test_check_conv2d_3():
def test_check_conv2d_4():
Tensor(np.ones([2, 2]))
m = nn.Conv2d(3, 64, (3, 3), stride=2, pad_mode='pad', padding=4)
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
def test_check_conv2d_bias():
m = nn.Conv2d(3, 64, 3, bias_init='zeros')
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
@@ -27,7 +27,7 @@ def test_dense_defaultbias_noactivation():
assert dense.activation is None
input_data = Tensor(np.random.randint(0, 255, [1, 3]).astype(np.float32))
output = dense.construct(input_data)
output = dense(input_data)
output_np = output.asnumpy()
assert isinstance(output_np[0][0], (np.float32, np.float64))
@@ -37,7 +37,7 @@ def test_dense_defaultweight():
dense = nn.Dense(3, 2, bias_init=bias)
#batch_size 1 && 3-channel RGB
input_data = Tensor(np.random.randint(0, 255, [1, 3]).astype(np.float32))
output = dense.construct(input_data)
output = dense(input_data)
output_np = output.asnumpy()
assert isinstance(output_np[0][0], (np.float32, np.float64))
@@ -48,7 +48,7 @@ def test_dense_bias():
dense = nn.Dense(3, 2, weight, bias)
input_data = Tensor(np.random.randint(0, 255, [2, 3]).astype(np.float32))
output = dense.construct(input_data)
output = dense(input_data)
output_np = output.asnumpy()
assert isinstance(output_np[0][0], (np.float32, np.float64))
@@ -58,7 +58,7 @@ def test_dense_nobias():
dense = nn.Dense(3, 2, weight, has_bias=False)
input_data = Tensor(np.random.randint(0, 255, [2, 3]).astype(np.float32))
output = dense.construct(input_data)
output = dense(input_data)
output_np = output.asnumpy()
assert isinstance(output_np[0][0], (np.float32, np.float64))
@@ -73,7 +73,7 @@ def test_dense_str_activation():
assert isinstance(dense.activation, nn.ReLU)
input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
output = dense.construct(input_data)
output = dense(input_data)
output_np = output.asnumpy()
assert isinstance(output_np[0][0], np.float32)
@@ -264,6 +264,7 @@ def test_grad_inline_bprop_multi_input():
net = InlineMutilTwoInputParameterCell()
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
input2 = Tensor(np.ones([2, 2]).astype(np.float32))
net.init_parameters_data()
grads = C.grad_all(net)(input1, input2)
assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all()
assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all()
@@ -133,6 +133,6 @@ def test_lenet_grad():
print("fail to run optimizer")
# verification
if i == verification_step:
fw_output = net.construct(input_data)
loss_output = loss.construct(fw_output, label)
fw_output = net(input_data)
loss_output = loss(fw_output, label)
print("The loss of %s-th iteration is %s" % (i, loss_output.asnumpy()))
@@ -151,7 +151,7 @@ def test_softmaxloss_grad():
predict = Tensor(np.ones([1, 64]))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
print("pynative run")
out = net.construct(predict, label)
out = net(predict, label)
print("out:", out)
def test_stop_gradient_1():
@@ -22,6 +22,10 @@ from scipy import stats
import mindspore as ms
import mindspore.common.initializer as init
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore import context
from mindspore.nn import Conv2d
from ..ut_filter import non_graph_engine
@@ -55,8 +59,8 @@ def _check_uniform(tensor, boundary_a, boundary_b):
def test_init_Initializer():
tensor = init.initializer(InitTwo(), [2, 2], ms.int32)
assert tensor.shape() == (2, 2)
_check_value(tensor, 2, 2)
assert tensor.shape == (2, 2)
_check_value(tensor.to_tensor(), 2, 2)
def test_init_tensor():
@@ -67,71 +71,71 @@ def test_init_tensor():
def test_init_zero_default_dtype():
tensor = init.initializer(init.Zero(), [2, 2])
assert tensor.dtype() == ms.float32
_check_value(tensor, 0, 0)
assert tensor.dtype == ms.float32
_check_value(tensor.to_tensor(), 0, 0)
def test_init_zero():
tensor = init.initializer(init.Zero(), [2, 2], ms.float32)
_check_value(tensor, 0, 0)
_check_value(tensor.to_tensor(), 0, 0)
def test_init_zero_alias_default_dtype():
tensor = init.initializer('zeros', [1, 2])
assert tensor.dtype() == ms.float32
_check_value(tensor, 0, 0)
assert tensor.dtype == ms.float32
_check_value(tensor.to_tensor(), 0, 0)
def test_init_zero_alias():
tensor = init.initializer('zeros', [1, 2], ms.float32)
_check_value(tensor, 0, 0)
_check_value(tensor.to_tensor(), 0, 0)
def test_init_one():
tensor = init.initializer(init.One(), [2, 2], ms.float32)
_check_value(tensor, 1, 1)
_check_value(tensor.to_tensor(), 1, 1)
def test_init_one_alias():
tensor = init.initializer('ones', [1, 2], ms.float32)
_check_value(tensor, 1, 1)
_check_value(tensor.to_tensor(), 1, 1)
def test_init_constant():
tensor = init.initializer(init.Constant(1), [2, 2], ms.float32)
_check_value(tensor, 1, 1)
_check_value(tensor.to_tensor(), 1, 1)
def test_init_uniform():
scale = 10
tensor = init.initializer(init.Uniform(scale=scale), [5, 4], ms.float32)
_check_value(tensor, -scale, scale)
_check_value(tensor.to_tensor(), -scale, scale)
def test_init_uniform_alias():
scale = 100
tensor = init.initializer('uniform', [5, 4], ms.float32)
_check_value(tensor, -scale, scale)
_check_value(tensor.to_tensor(), -scale, scale)
def test_init_normal():
tensor = init.initializer(init.Normal(), [5, 4], ms.float32)
assert isinstance(tensor, ms.Tensor), 'tensor init failed!'
assert isinstance(tensor, init.Normal), 'Normal init failed!'
def test_init_truncated_normal():
tensor = init.initializer(init.TruncatedNormal(), [5, 4], ms.float32)
assert isinstance(tensor, ms.Tensor), 'tensor init failed!'
assert isinstance(tensor, init.TruncatedNormal), 'TruncatedNormal init failed!'
def test_init_normal_alias():
tensor = init.initializer('normal', [5, 4], ms.float32)
assert isinstance(tensor, ms.Tensor), 'tensor init failed!'
assert isinstance(tensor, init.Normal), 'Normal init failed!'
def test_init_truncatednormal_alias():
tensor = init.initializer('truncatednormal', [5, 4], ms.float32)
assert isinstance(tensor, ms.Tensor), 'tensor init failed!'
assert isinstance(tensor, init.TruncatedNormal), 'TruncatedNormal init failed!'
def test_init_abnormal():
@@ -142,12 +146,12 @@ def test_init_abnormal():
def test_init_xavier_uniform():
""" test_init_xavier_uniform """
gain = 1.2
tensor1 = init.initializer(init.XavierUniform(gain=gain), [20, 22], ms.float32)
tensor2 = init.initializer(init.XavierUniform(), [20, 22], ms.float32)
tensor3 = init.initializer(init.XavierUniform(gain=gain), [20, 22, 5, 5], ms.float32)
tensor4 = init.initializer(init.XavierUniform(), [20, 22, 5, 5], ms.float32)
tensor5 = init.initializer('xavier_uniform', [20, 22, 5, 5], ms.float32)
tensor6 = init.initializer('xavier_uniform', [20, 22], ms.float32)
tensor1 = init.initializer(init.XavierUniform(gain=gain), [20, 22], ms.float32).to_tensor()
tensor2 = init.initializer(init.XavierUniform(), [20, 22], ms.float32).to_tensor()
tensor3 = init.initializer(init.XavierUniform(gain=gain), [20, 22, 5, 5], ms.float32).to_tensor()
tensor4 = init.initializer(init.XavierUniform(), [20, 22, 5, 5], ms.float32).to_tensor()
tensor5 = init.initializer('xavier_uniform', [20, 22, 5, 5], ms.float32).to_tensor()
tensor6 = init.initializer('xavier_uniform', [20, 22], ms.float32).to_tensor()
tensor_dict = {tensor1: gain, tensor2: None, tensor3: gain, tensor4: None, tensor5: None, tensor6: None}
for tensor, gain_value in tensor_dict.items():
@@ -167,7 +171,7 @@ def test_init_xavier_uniform():
def test_init_xavier_uniform_error():
with py.raises(ValueError):
init.initializer(init.XavierUniform(), [6], ms.float32)
init.initializer(init.XavierUniform(), [6], ms.float32).to_tensor()
def test_init_he_uniform():
@@ -176,7 +180,7 @@ def test_init_he_uniform():
tensor2 = init.initializer(init.HeUniform(), [20, 22, 5, 5], ms.float32)
tensor3 = init.initializer('he_uniform', [20, 22, 5, 5], ms.float32)
tensor4 = init.initializer('he_uniform', [20, 22], ms.float32)
tensors = [tensor1, tensor2, tensor3, tensor4]
tensors = [tensor1.to_tensor(), tensor2.to_tensor(), tensor3.to_tensor(), tensor4.to_tensor()]
for tensor in tensors:
shape = tensor.asnumpy().shape
@@ -192,7 +196,7 @@ def test_init_he_uniform():
def test_init_he_uniform_error():
with py.raises(ValueError):
init.initializer(init.HeUniform(), [6], ms.float32)
init.initializer(init.HeUniform(), [6], ms.float32).to_tensor()
def test_conv2d_abnormal_kernel_negative():
@@ -216,9 +220,30 @@ def test_conv2d_abnormal_kernel_normal():
@non_graph_engine
def test_conv2d_abnormal_kernel_truncated_normal():
input_data = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32)
input_data = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32).to_tensor()
context.set_context(mode=context.GRAPH_MODE)
model = ms.Model(
Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3,
padding=0, weight_init="truncatednormal"))
model.predict(input_data)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.add = P.TensorAdd()
self.t1 = Parameter(init.initializer('uniform', [5, 4], ms.float32), name="w1")
self.t2 = Parameter(init.initializer(init.TruncatedNormal(), [5, 4], ms.float32), name="w2")
def construct(self, x):
z = self.add(x, self.t1)
z = self.add(z, self.t2)
return z
def test_weight_shape():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
a = np.arange(20).reshape(5, 4)
t = Tensor(a, dtype=ms.float32)
net = Net()
out = net(t)
print(out)
@@ -198,6 +198,7 @@ def test_load_param_into_net_error_dict():
def test_load_param_into_net_erro_dict_param():
net = Net(10)
net.init_parameters_data()
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
parameter_dict = {}
@@ -210,6 +211,7 @@ def test_load_param_into_net_erro_dict_param():
def test_load_param_into_net_has_more_param():
""" test_load_param_into_net_has_more_param """
net = Net(10)
net.init_parameters_data()
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
parameter_dict = {}
@@ -225,6 +227,7 @@ def test_load_param_into_net_has_more_param():
def test_load_param_into_net_param_type_and_shape_error():
net = Net(10)
net.init_parameters_data()
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
parameter_dict = {}
@@ -236,6 +239,7 @@ def test_load_param_into_net_param_type_and_shape_error():
def test_load_param_into_net_param_type_error():
net = Net(10)
net.init_parameters_data()
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
parameter_dict = {}
@@ -248,6 +252,7 @@ def test_load_param_into_net_param_type_error():
def test_load_param_into_net_param_shape_error():
net = Net(10)
net.init_parameters_data()
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
parameter_dict = {}
@@ -260,6 +265,7 @@ def test_load_param_into_net_param_shape_error():
def test_load_param_into_net():
net = Net(10)
net.init_parameters_data()
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
parameter_dict = {}