| @@ -79,12 +79,12 @@ if __name__ == '__main__': | |||||
| for _, cell in net.cells_and_names(): | for _, cell in net.cells_and_names(): | ||||
| if isinstance(cell, nn.Conv2d): | if isinstance(cell, nn.Conv2d): | ||||
| cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), | cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), | ||||
| cell.weight.default_input.shape(), | |||||
| cell.weight.default_input.dtype()).to_tensor() | |||||
| cell.weight.default_input.shape, | |||||
| cell.weight.default_input.dtype).to_tensor() | |||||
| if isinstance(cell, nn.Dense): | if isinstance(cell, nn.Dense): | ||||
| cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), | cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), | ||||
| cell.weight.default_input.shape(), | |||||
| cell.weight.default_input.dtype()).to_tensor() | |||||
| cell.weight.default_input.shape, | |||||
| cell.weight.default_input.dtype).to_tensor() | |||||
| if not config.use_label_smooth: | if not config.use_label_smooth: | ||||
| config.label_smooth_factor = 0.0 | config.label_smooth_factor = 0.0 | ||||
| @@ -338,15 +338,15 @@ class Dense_Thor(Cell): | |||||
| self.has_bias = check_bool(has_bias) | self.has_bias = check_bool(has_bias) | ||||
| self.thor = True | self.thor = True | ||||
| if isinstance(weight_init, Tensor): | if isinstance(weight_init, Tensor): | ||||
| if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ | |||||
| weight_init.shape()[1] != in_channels: | |||||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||||
| weight_init.shape[1] != in_channels: | |||||
| raise ValueError("weight_init shape error") | raise ValueError("weight_init shape error") | ||||
| self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") | self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") | ||||
| if self.has_bias: | if self.has_bias: | ||||
| if isinstance(bias_init, Tensor): | if isinstance(bias_init, Tensor): | ||||
| if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: | |||||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||||
| raise ValueError("bias_init shape error") | raise ValueError("bias_init shape error") | ||||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | ||||
| @@ -56,7 +56,7 @@ def init_net_param(network, init_value='ones'): | |||||
| params = network.trainable_params() | params = network.trainable_params() | ||||
| for p in params: | for p in params: | ||||
| if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | ||||
| p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype())) | |||||
| p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype)) | |||||
| def main(): | def main(): | ||||
| @@ -384,6 +384,28 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| .def(py::init<py::tuple, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | .def(py::init<py::tuple, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | ||||
| .def(py::init<Tensor, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | .def(py::init<Tensor, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | ||||
| .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_) | .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_) | ||||
| .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter( | |||||
| Get the tensor's data type. | |||||
| Returns: | |||||
| type, the data type of tensor. | |||||
| Examples: | |||||
| >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) | |||||
| >>> data.dtype | |||||
| Int32 | |||||
| )mydelimiter") | |||||
| .def_property_readonly("shape", &Tensor::GetPyTupleShape, R"mydelimiter( | |||||
| Get the tensor's shape. | |||||
| Returns: | |||||
| tuple[int], the shape of tensor. | |||||
| Examples: | |||||
| >>> data = mindspore.Tensor(np.ones((3, 3))) | |||||
| >>> data.shape() | |||||
| (3, 3) | |||||
| )mydelimiter") | |||||
| .def("asnumpy", &Tensor::data_sync, R"mydelimiter( | .def("asnumpy", &Tensor::data_sync, R"mydelimiter( | ||||
| Convert tensor to numpy.ndarray. | Convert tensor to numpy.ndarray. | ||||
| @@ -437,17 +459,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| >>> data.dim() | >>> data.dim() | ||||
| 2 | 2 | ||||
| )mydelimiter") | )mydelimiter") | ||||
| .def("dtype", &Tensor::Dtype, R"mydelimiter( | |||||
| Get the tensor's data type. | |||||
| Returns: | |||||
| type, the data type of tensor. | |||||
| Examples: | |||||
| >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) | |||||
| >>> data.dtype() | |||||
| Int32 | |||||
| )mydelimiter") | |||||
| .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( | .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( | ||||
| Set the tensor's data type. | Set the tensor's data type. | ||||
| @@ -459,17 +470,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| >>> data.set_dtype(mindspore.int32) | >>> data.set_dtype(mindspore.int32) | ||||
| mindspore.int32 | mindspore.int32 | ||||
| )mydelimiter") | )mydelimiter") | ||||
| .def("shape", &Tensor::GetPyTupleShape, R"mydelimiter( | |||||
| Get the tensor's shape. | |||||
| Returns: | |||||
| tuple[int], the shape of tensor. | |||||
| Examples: | |||||
| >>> data = mindspore.Tensor(np.ones((3, 3))) | |||||
| >>> data.shape() | |||||
| (3, 3) | |||||
| )mydelimiter") | |||||
| .def("__str__", &Tensor::ToString) | .def("__str__", &Tensor::ToString) | ||||
| .def("__repr__", &Tensor::ToStringRepr) | .def("__repr__", &Tensor::ToStringRepr) | ||||
| .def(py::pickle( | .def(py::pickle( | ||||
| @@ -488,8 +488,8 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor") | (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor") | ||||
| .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) | .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) | ||||
| .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_) | .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_) | ||||
| .def("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | |||||
| .def("shape", &MetaTensor::shape, "Get the MetaTensor's shape."); | |||||
| .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | |||||
| .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape."); | |||||
| })); | })); | ||||
| } // namespace tensor | } // namespace tensor | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -170,8 +170,8 @@ def get_py_obj_dtype(obj): | |||||
| Type of MindSpore type. | Type of MindSpore type. | ||||
| """ | """ | ||||
| # Tensor | # Tensor | ||||
| if hasattr(obj, 'dtype') and callable(obj.dtype) and isinstance(obj.dtype(), typing.Type): | |||||
| return tensor_type(obj.dtype()) | |||||
| if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): | |||||
| return tensor_type(obj.dtype) | |||||
| if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): | if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): | ||||
| return function | return function | ||||
| if isinstance(obj, (typing.Type, type)): | if isinstance(obj, (typing.Type, type)): | ||||
| @@ -331,11 +331,11 @@ def initializer(init, shape=None, dtype=mstype.float32): | |||||
| raise TypeError("Unsupported init type '{}'.".format(type(init))) | raise TypeError("Unsupported init type '{}'.".format(type(init))) | ||||
| if isinstance(init, Tensor): | if isinstance(init, Tensor): | ||||
| init_shape = init.shape() | |||||
| init_shape = init.shape | |||||
| shape = shape if isinstance(shape, (tuple, list)) else [shape] | shape = shape if isinstance(shape, (tuple, list)) else [shape] | ||||
| if shape is not None and init_shape != tuple(shape): | if shape is not None and init_shape != tuple(shape): | ||||
| raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and " | raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and " | ||||
| "the variable shape {}.".format(list(init.shape()), shape)) | |||||
| "the variable shape {}.".format(list(init.shape), shape)) | |||||
| return init | return init | ||||
| if isinstance(shape, list): | if isinstance(shape, list): | ||||
| @@ -140,8 +140,8 @@ class Parameter: | |||||
| x.name = prefix + '.' + x.name | x.name = prefix + '.' + x.name | ||||
| x.is_init = False | x.is_init = False | ||||
| if init != 'same': | if init != 'same': | ||||
| shape = self.default_input.shape() | |||||
| dtype = self.default_input.dtype() | |||||
| shape = self.default_input.shape | |||||
| dtype = self.default_input.dtype | |||||
| if isinstance(init, (str, Initializer, numbers.Number)): | if isinstance(init, (str, Initializer, numbers.Number)): | ||||
| x.init_mode = initializer(init, shape=shape, dtype=dtype) | x.init_mode = initializer(init, shape=shape, dtype=dtype) | ||||
| x.default_input = MetaTensor(dtype, shape) | x.default_input = MetaTensor(dtype, shape) | ||||
| @@ -45,13 +45,13 @@ class Tensor(Tensor_): | |||||
| >>> # init a tensor with input data | >>> # init a tensor with input data | ||||
| >>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32) | >>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32) | ||||
| >>> assert isinstance(t1, Tensor) | >>> assert isinstance(t1, Tensor) | ||||
| >>> assert t1.shape() == (1, 2, 3) | |||||
| >>> assert t1.dtype() == mindspore.float32 | |||||
| >>> assert t1.shape == (1, 2, 3) | |||||
| >>> assert t1.dtype == mindspore.float32 | |||||
| >>> | >>> | ||||
| >>> # init a tensor with a float scalar | >>> # init a tensor with a float scalar | ||||
| >>> t2 = Tensor(0.1) | >>> t2 = Tensor(0.1) | ||||
| >>> assert isinstance(t2, Tensor) | >>> assert isinstance(t2, Tensor) | ||||
| >>> assert t2.dtype() == mindspore.float64 | |||||
| >>> assert t2.dtype == mindspore.float64 | |||||
| """ | """ | ||||
| def __init__(self, input_data, dtype=None): | def __init__(self, input_data, dtype=None): | ||||
| @@ -80,7 +80,7 @@ class Tensor(Tensor_): | |||||
| return False | return False | ||||
| # The GE backend don't support single `Equal` operator execution. | # The GE backend don't support single `Equal` operator execution. | ||||
| # bool type is not supported for `Equal` operator in backend. | # bool type is not supported for `Equal` operator in backend. | ||||
| if context.get_context("enable_ge") or self.dtype() == mstype.bool_ or other.dtype() == mstype.bool_: | |||||
| if context.get_context("enable_ge") or self.dtype == mstype.bool_ or other.dtype == mstype.bool_: | |||||
| return Tensor(np.array(self.asnumpy() == other.asnumpy())) | return Tensor(np.array(self.asnumpy() == other.asnumpy())) | ||||
| return tensor_operator_registry.get('__eq__')(self, other) | return tensor_operator_registry.get('__eq__')(self, other) | ||||
| @@ -166,7 +166,7 @@ class Tensor(Tensor_): | |||||
| return out[0] | return out[0] | ||||
| def __str__(self): | def __str__(self): | ||||
| if self.dtype() == mstype.type_none: | |||||
| if self.dtype == mstype.type_none: | |||||
| return "Unknown Tensor type!" | return "Unknown Tensor type!" | ||||
| return str(self.asnumpy()) | return str(self.asnumpy()) | ||||
| @@ -267,21 +267,21 @@ class MobileNetV2(nn.Cell): | |||||
| if isinstance(m, (nn.Conv2d, DepthwiseConv)): | if isinstance(m, (nn.Conv2d, DepthwiseConv)): | ||||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | ||||
| m.weight.data.shape()).astype("float32"))) | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data( | m.bias.set_parameter_data( | ||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.BatchNorm2d): | elif isinstance(m, nn.BatchNorm2d): | ||||
| m.gamma.set_parameter_data( | m.gamma.set_parameter_data( | ||||
| Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) | |||||
| Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) | |||||
| m.beta.set_parameter_data( | m.beta.set_parameter_data( | ||||
| Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.Dense): | elif isinstance(m, nn.Dense): | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal( | m.weight.set_parameter_data(Tensor(np.random.normal( | ||||
| 0, 0.01, m.weight.data.shape()).astype("float32"))) | |||||
| 0, 0.01, m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data( | m.bias.set_parameter_data( | ||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| def mobilenet_v2(**kwargs): | def mobilenet_v2(**kwargs): | ||||
| @@ -322,21 +322,21 @@ class MobileNetV3(nn.Cell): | |||||
| if isinstance(m, (nn.Conv2d)): | if isinstance(m, (nn.Conv2d)): | ||||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | ||||
| m.weight.data.shape()).astype("float32"))) | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data( | m.bias.set_parameter_data( | ||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.BatchNorm2d): | elif isinstance(m, nn.BatchNorm2d): | ||||
| m.gamma.set_parameter_data( | m.gamma.set_parameter_data( | ||||
| Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) | |||||
| Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) | |||||
| m.beta.set_parameter_data( | m.beta.set_parameter_data( | ||||
| Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.Dense): | elif isinstance(m, nn.Dense): | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal( | m.weight.set_parameter_data(Tensor(np.random.normal( | ||||
| 0, 0.01, m.weight.data.shape()).astype("float32"))) | |||||
| 0, 0.01, m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data( | m.bias.set_parameter_data( | ||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| def mobilenet_v3(model_name, **kwargs): | def mobilenet_v3(model_name, **kwargs): | ||||
| @@ -131,7 +131,7 @@ class Flatten(Cell): | |||||
| Examples: | Examples: | ||||
| >>> net = nn.Flatten() | >>> net = nn.Flatten() | ||||
| >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32) | >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32) | ||||
| >>> input.shape() | |||||
| >>> input.shape | |||||
| (2, 2, 2) | (2, 2, 2) | ||||
| >>> net(input) | >>> net(input) | ||||
| [[1.2 1.2 2.1 2.1] | [[1.2 1.2 2.1 2.1] | ||||
| @@ -198,15 +198,15 @@ class Dense(Cell): | |||||
| self.has_bias = check_bool(has_bias) | self.has_bias = check_bool(has_bias) | ||||
| if isinstance(weight_init, Tensor): | if isinstance(weight_init, Tensor): | ||||
| if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ | |||||
| weight_init.shape()[1] != in_channels: | |||||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||||
| weight_init.shape[1] != in_channels: | |||||
| raise ValueError("weight_init shape error") | raise ValueError("weight_init shape error") | ||||
| self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") | self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") | ||||
| if self.has_bias: | if self.has_bias: | ||||
| if isinstance(bias_init, Tensor): | if isinstance(bias_init, Tensor): | ||||
| if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: | |||||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||||
| raise ValueError("bias_init shape error") | raise ValueError("bias_init shape error") | ||||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | ||||
| @@ -69,7 +69,7 @@ class Conv2d(Cell): | |||||
| Examples: | Examples: | ||||
| >>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU') | >>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU') | ||||
| >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | ||||
| >>> net(input).shape() | |||||
| >>> net(input).shape | |||||
| (1, 240, 1024, 640) | (1, 240, 1024, 640) | ||||
| """ | """ | ||||
| @@ -168,7 +168,7 @@ class Conv2d(_Conv): | |||||
| Examples: | Examples: | ||||
| >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal') | >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal') | ||||
| >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | ||||
| >>> net(input).shape() | |||||
| >>> net(input).shape | |||||
| (1, 240, 1024, 640) | (1, 240, 1024, 640) | ||||
| """ | """ | ||||
| @cell_attr_register | @cell_attr_register | ||||
| @@ -56,7 +56,7 @@ class Embedding(Cell): | |||||
| >>> | >>> | ||||
| >>> # Maps the input word IDs to word embedding. | >>> # Maps the input word IDs to word embedding. | ||||
| >>> output = net(input_data) | >>> output = net(input_data) | ||||
| >>> output.shape() | |||||
| >>> output.shape | |||||
| (8, 128, 768) | (8, 128, 768) | ||||
| """ | """ | ||||
| def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): | def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): | ||||
| @@ -474,7 +474,7 @@ class LayerNorm(Cell): | |||||
| Examples: | Examples: | ||||
| >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) | >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) | ||||
| >>> shape1 = x.shape()[1:] | |||||
| >>> shape1 = x.shape[1:] | |||||
| >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) | >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) | ||||
| >>> m(x) | >>> m(x) | ||||
| """ | """ | ||||
| @@ -113,7 +113,7 @@ class MaxPool2d(_PoolNd): | |||||
| [0. 0. 4. 0.] | [0. 0. 4. 0.] | ||||
| [1. 8. 7. 0.]]]] | [1. 8. 7. 0.]]]] | ||||
| >>> output = pool(x) | >>> output = pool(x) | ||||
| >>> output.shape() | |||||
| >>> output.shape | |||||
| (1, 2, 2, 2) | (1, 2, 2, 2) | ||||
| >>> output | >>> output | ||||
| [[[[7. 8.] | [[[[7. 8.] | ||||
| @@ -195,7 +195,7 @@ class AvgPool2d(_PoolNd): | |||||
| [0. 8. 9. 7.] | [0. 8. 9. 7.] | ||||
| [2. 1. 4. 9.]]]] | [2. 1. 4. 9.]]]] | ||||
| >>> output = pool(x) | >>> output = pool(x) | ||||
| >>> output.shape() | |||||
| >>> output.shape | |||||
| (1, 2, 2, 2) | (1, 2, 2, 2) | ||||
| >>> output | >>> output | ||||
| [[[[4.888889 4.4444447] | [[[[4.888889 4.4444447] | ||||
| @@ -260,7 +260,7 @@ class AvgPool1d(_PoolNd): | |||||
| >>> pool = nn.AvgPool1d(kernel_size=6, strides=1) | >>> pool = nn.AvgPool1d(kernel_size=6, strides=1) | ||||
| >>> x = Tensor(np.random.randint(0, 10, [1, 3, 6]), mindspore.float32) | >>> x = Tensor(np.random.randint(0, 10, [1, 3, 6]), mindspore.float32) | ||||
| >>> output = pool(x) | >>> output = pool(x) | ||||
| >>> output.shape() | |||||
| >>> output.shape | |||||
| (1, 3, 1) | (1, 3, 1) | ||||
| """ | """ | ||||
| @@ -571,8 +571,8 @@ class DenseQuant(Cell): | |||||
| self.has_bias = check_bool(has_bias) | self.has_bias = check_bool(has_bias) | ||||
| if isinstance(weight_init, Tensor): | if isinstance(weight_init, Tensor): | ||||
| if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ | |||||
| weight_init.shape()[1] != in_channels: | |||||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||||
| weight_init.shape[1] != in_channels: | |||||
| raise ValueError("weight_init shape error") | raise ValueError("weight_init shape error") | ||||
| self.weight = Parameter(initializer( | self.weight = Parameter(initializer( | ||||
| @@ -580,7 +580,7 @@ class DenseQuant(Cell): | |||||
| if self.has_bias: | if self.has_bias: | ||||
| if isinstance(bias_init, Tensor): | if isinstance(bias_init, Tensor): | ||||
| if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: | |||||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||||
| raise ValueError("bias_init shape error") | raise ValueError("bias_init shape error") | ||||
| self.bias = Parameter(initializer( | self.bias = Parameter(initializer( | ||||
| @@ -23,7 +23,7 @@ greater = base.MultitypeFuncGraph("greater") | |||||
| @greater.register("Number", "Number") | @greater.register("Number", "Number") | ||||
| def _greater_scala(x, y): | |||||
| def _greater_scalar(x, y): | |||||
| """ | """ | ||||
| Determine whether two numbers are greater. | Determine whether two numbers are greater. | ||||
| @@ -145,10 +145,10 @@ class SameTypeShape(PrimitiveWithInfer): | |||||
| def __call__(self, x, y): | def __call__(self, x, y): | ||||
| """run in PyNative mode""" | """run in PyNative mode""" | ||||
| validator.check_value_type("x", x, Tensor, self.name) | |||||
| validator.check_value_type("y", y, Tensor, self.name) | |||||
| validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError) | |||||
| validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name) | |||||
| validator.check_value_type('x', x, Tensor, self.name) | |||||
| validator.check_value_type('y', y, Tensor, self.name) | |||||
| validator.check('x dtype', x.dtype, 'y dtype', y.dtype, Rel.EQ, self.name, TypeError) | |||||
| validator.check('x shape', x.shape, 'y shape', y.shape, Rel.EQ, self.name) | |||||
| return x | return x | ||||
| def __infer__(self, x, y): | def __infer__(self, x, y): | ||||
| @@ -187,7 +187,7 @@ class Cast(PrimitiveWithInfer): | |||||
| def check_elim(self, x, dtype): | def check_elim(self, x, dtype): | ||||
| if isinstance(x, Tensor): | if isinstance(x, Tensor): | ||||
| if x.dtype() == dtype: | |||||
| if x.dtype == dtype: | |||||
| return (True, x) | return (True, x) | ||||
| return (False, None) | return (False, None) | ||||
| raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) | raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) | ||||
| @@ -498,7 +498,7 @@ class GatherV2(PrimitiveWithInfer): | |||||
| The original Tensor. | The original Tensor. | ||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | ||||
| Specifies the indices of elements of the original Tensor. Must be in the range | Specifies the indices of elements of the original Tensor. Must be in the range | ||||
| `[0, input_param.shape()[axis])`. | |||||
| `[0, input_param.shape[axis])`. | |||||
| - **axis** (int) - Specifies the dimension index to gather indices. | - **axis** (int) - Specifies the dimension index to gather indices. | ||||
| Outputs: | Outputs: | ||||
| @@ -542,7 +542,7 @@ class SparseGatherV2(GatherV2): | |||||
| The original Tensor. | The original Tensor. | ||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | ||||
| Specifies the indices of elements of the original Tensor. Must be in the range | Specifies the indices of elements of the original Tensor. Must be in the range | ||||
| `[0, input_param.shape()[axis])`. | |||||
| `[0, input_param.shape[axis])`. | |||||
| - **axis** (int) - Specifies the dimension index to gather indices. | - **axis** (int) - Specifies the dimension index to gather indices. | ||||
| Outputs: | Outputs: | ||||
| @@ -700,7 +700,7 @@ class Split(PrimitiveWithInfer): | |||||
| output_num (int): The number of output tensors. Default: 1. | output_num (int): The number of output tensors. Default: 1. | ||||
| Raises: | Raises: | ||||
| ValueError: If axis is out of the range [-len(input_x.shape()), len(input_x.shape())), | |||||
| ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)), | |||||
| or if the output_num is less than or equal to 0, or if the | or if the output_num is less than or equal to 0, or if the | ||||
| dimension which to split cannot be evenly divided by output_num. | dimension which to split cannot be evenly divided by output_num. | ||||
| @@ -1644,7 +1644,7 @@ class Unpack(PrimitiveWithInfer): | |||||
| A tuple of Tensors, the shape of each objects is same. | A tuple of Tensors, the shape of each objects is same. | ||||
| Raises: | Raises: | ||||
| ValueError: If axis is out of the range [-len(input_x.shape()), len(input_x.shape())). | |||||
| ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)). | |||||
| Examples: | Examples: | ||||
| >>> unpack = P.Unpack() | >>> unpack = P.Unpack() | ||||
| @@ -1850,7 +1850,7 @@ class StridedSlice(PrimitiveWithInfer): | |||||
| >>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32) | >>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32) | ||||
| >>> slice = P.StridedSlice() | >>> slice = P.StridedSlice() | ||||
| >>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1)) | >>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1)) | ||||
| >>> output.shape() | |||||
| >>> output.shape | |||||
| (1, 1, 3) | (1, 1, 3) | ||||
| >>> output | >>> output | ||||
| [[[3, 3, 3]]] | [[[3, 3, 3]]] | ||||
| @@ -1974,7 +1974,7 @@ class Diag(PrimitiveWithInfer): | |||||
| if x is None: | if x is None: | ||||
| return None | return None | ||||
| # do constant-folding only when x rank is 1 | # do constant-folding only when x rank is 1 | ||||
| if len(x.shape()) != 1: | |||||
| if len(x.shape) != 1: | |||||
| return None | return None | ||||
| ret = np.diag(x.asnumpy()) | ret = np.diag(x.asnumpy()) | ||||
| return Tensor(ret) | return Tensor(ret) | ||||
| @@ -2026,7 +2026,7 @@ class DiagPart(PrimitiveWithInfer): | |||||
| if x is None: | if x is None: | ||||
| return None | return None | ||||
| # do constant-folding only when x rank is 2 | # do constant-folding only when x rank is 2 | ||||
| if len(x.shape()) != 2: | |||||
| if len(x.shape) != 2: | |||||
| return None | return None | ||||
| ret = np.diag(x.asnumpy()) | ret = np.diag(x.asnumpy()) | ||||
| return Tensor(ret) | return Tensor(ret) | ||||
| @@ -2329,8 +2329,8 @@ class NMSWithMask(PrimitiveWithInfer): | |||||
| def infer_shape(self, bboxes_shape): | def infer_shape(self, bboxes_shape): | ||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | ||||
| validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) | |||||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||||
| validator.check_integer("bboxes.shape[0]", bboxes_shape[0], 0, Rel.GT, cls_name) | |||||
| validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||||
| num = bboxes_shape[0] | num = bboxes_shape[0] | ||||
| return (bboxes_shape, (num,), (num,)) | return (bboxes_shape, (num,), (num,)) | ||||
| @@ -78,7 +78,7 @@ class Flatten(PrimitiveWithInfer): | |||||
| >>> input_tensor = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32) | >>> input_tensor = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32) | ||||
| >>> flatten = P.Flatten() | >>> flatten = P.Flatten() | ||||
| >>> output = flatten(input_tensor) | >>> output = flatten(input_tensor) | ||||
| >>> assert output.shape() == (1, 24) | |||||
| >>> assert output.shape == (1, 24) | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -840,7 +840,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): | |||||
| >>> weight = Tensor(np.ones([1, 32, 3, 3]), mindspore.float32) | >>> weight = Tensor(np.ones([1, 32, 3, 3]), mindspore.float32) | ||||
| >>> depthwise_conv2d = P.DepthwiseConv2dNative(channel_multiplier = 3, kernel_size = (3, 3)) | >>> depthwise_conv2d = P.DepthwiseConv2dNative(channel_multiplier = 3, kernel_size = (3, 3)) | ||||
| >>> output = depthwise_conv2d(input, weight) | >>> output = depthwise_conv2d(input, weight) | ||||
| >>> assert output.shape() == (10, 96, 30, 30) | |||||
| >>> assert output.shape == (10, 96, 30, 30) | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -2057,7 +2057,7 @@ class DropoutDoMask(PrimitiveWithInfer): | |||||
| >>> dropout_do_mask = P.DropoutDoMask() | >>> dropout_do_mask = P.DropoutDoMask() | ||||
| >>> mask = dropout_gen_mask(shape, keep_prob) | >>> mask = dropout_gen_mask(shape, keep_prob) | ||||
| >>> output = dropout_do_mask(x, mask, keep_prob) | >>> output = dropout_do_mask(x, mask, keep_prob) | ||||
| >>> assert output.shape() == (20, 16, 50) | |||||
| >>> assert output.shape == (20, 16, 50) | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -2114,7 +2114,7 @@ class ResizeBilinear(PrimitiveWithInfer): | |||||
| >>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.int32) | >>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.int32) | ||||
| >>> resize_bilinear = P.ResizeBilinear((5, 5)) | >>> resize_bilinear = P.ResizeBilinear((5, 5)) | ||||
| >>> result = resize_bilinear(tensor) | >>> result = resize_bilinear(tensor) | ||||
| >>> assert result.shape() == (5, 5) | |||||
| >>> assert result.shape == (5, 5) | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -157,8 +157,8 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None): | |||||
| data = Tensor(data) | data = Tensor(data) | ||||
| if not isinstance(data, Tensor): | if not isinstance(data, Tensor): | ||||
| raise ValueError("elements in tensors must be Tensor") | raise ValueError("elements in tensors must be Tensor") | ||||
| shape_ = data.shape() | |||||
| type_ = data.dtype() | |||||
| shape_ = data.shape | |||||
| type_ = data.dtype | |||||
| new_shape = () | new_shape = () | ||||
| batchsize_per_device = 1 | batchsize_per_device = 1 | ||||
| for i, item in enumerate(shape_): | for i, item in enumerate(shape_): | ||||
| @@ -42,17 +42,17 @@ def _special_process_par(par, new_par): | |||||
| Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor. | Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor. | ||||
| """ | """ | ||||
| par_shape_len = len(par.data.shape()) | |||||
| new_par_shape_len = len(new_par.data.shape()) | |||||
| par_shape_len = len(par.data.shape) | |||||
| new_par_shape_len = len(new_par.data.shape) | |||||
| delta_len = new_par_shape_len - par_shape_len | delta_len = new_par_shape_len - par_shape_len | ||||
| delta_i = 0 | delta_i = 0 | ||||
| for delta_i in range(delta_len): | for delta_i in range(delta_len): | ||||
| if new_par.data.shape()[par_shape_len + delta_i] != 1: | |||||
| if new_par.data.shape[par_shape_len + delta_i] != 1: | |||||
| break | break | ||||
| if delta_i == delta_len - 1: | if delta_i == delta_len - 1: | ||||
| new_val = new_par.data.asnumpy() | new_val = new_par.data.asnumpy() | ||||
| new_val = new_val.reshape(par.data.shape()) | |||||
| par.set_parameter_data(Tensor(new_val, par.data.dtype())) | |||||
| new_val = new_val.reshape(par.data.shape) | |||||
| par.set_parameter_data(Tensor(new_val, par.data.dtype)) | |||||
| return True | return True | ||||
| return False | return False | ||||
| @@ -61,17 +61,17 @@ def _update_param(param, new_param): | |||||
| """Updates param's data from new_param's data.""" | """Updates param's data from new_param's data.""" | ||||
| if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor): | if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor): | ||||
| if param.data.dtype() != new_param.data.dtype(): | |||||
| if param.data.dtype != new_param.data.dtype: | |||||
| logger.error("Failed to combine the net and the parameters for param %s.", param.name) | logger.error("Failed to combine the net and the parameters for param %s.", param.name) | ||||
| msg = ("Net parameters {} type({}) different from parameter_dict's({})" | msg = ("Net parameters {} type({}) different from parameter_dict's({})" | ||||
| .format(param.name, param.data.dtype(), new_param.data.dtype())) | |||||
| .format(param.name, param.data.dtype, new_param.data.dtype)) | |||||
| raise RuntimeError(msg) | raise RuntimeError(msg) | ||||
| if param.data.shape() != new_param.data.shape(): | |||||
| if param.data.shape != new_param.data.shape: | |||||
| if not _special_process_par(param, new_param): | if not _special_process_par(param, new_param): | ||||
| logger.error("Failed to combine the net and the parameters for param %s.", param.name) | logger.error("Failed to combine the net and the parameters for param %s.", param.name) | ||||
| msg = ("Net parameters {} shape({}) different from parameter_dict's({})" | msg = ("Net parameters {} shape({}) different from parameter_dict's({})" | ||||
| .format(param.name, param.data.shape(), new_param.data.shape())) | |||||
| .format(param.name, param.data.shape, new_param.data.shape)) | |||||
| raise RuntimeError(msg) | raise RuntimeError(msg) | ||||
| return | return | ||||
| @@ -79,12 +79,12 @@ def _update_param(param, new_param): | |||||
| return | return | ||||
| if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor): | if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor): | ||||
| if param.data.shape() != (1,) and param.data.shape() != (): | |||||
| if param.data.shape != (1,) and param.data.shape != (): | |||||
| logger.error("Failed to combine the net and the parameters for param %s.", param.name) | logger.error("Failed to combine the net and the parameters for param %s.", param.name) | ||||
| msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)." | msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)." | ||||
| .format(param.name, param.data.shape())) | |||||
| .format(param.name, param.data.shape)) | |||||
| raise RuntimeError(msg) | raise RuntimeError(msg) | ||||
| param.set_parameter_data(initializer(new_param.data, param.data.shape(), param.data.dtype())) | |||||
| param.set_parameter_data(initializer(new_param.data, param.data.shape, param.data.dtype)) | |||||
| elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor): | elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor): | ||||
| logger.error("Failed to combine the net and the parameters for param %s.", param.name) | logger.error("Failed to combine the net and the parameters for param %s.", param.name) | ||||
| @@ -120,12 +120,12 @@ def save_checkpoint(parameter_list, ckpoint_file_name): | |||||
| param["data"].init_data() | param["data"].init_data() | ||||
| param_data = param["data"].asnumpy().reshape(-1) | param_data = param["data"].asnumpy().reshape(-1) | ||||
| param_tensor.tensor_content = param_data.tostring() | param_tensor.tensor_content = param_data.tostring() | ||||
| param_tensor.tensor_type = str(param["data"].dtype()) | |||||
| param_tensor.tensor_type = str(param["data"].dtype) | |||||
| if param['data'].shape() == (): | |||||
| if param['data'].shape == (): | |||||
| param_tensor.dims.append(0) | param_tensor.dims.append(0) | ||||
| else: | else: | ||||
| for dim in param['data'].shape(): | |||||
| for dim in param['data'].shape: | |||||
| param_tensor.dims.append(dim) | param_tensor.dims.append(dim) | ||||
| with open(ckpoint_file_name, "wb") as f: | with open(ckpoint_file_name, "wb") as f: | ||||
| @@ -73,7 +73,7 @@ class FusedLayerNorm(Cell): | |||||
| Examples: | Examples: | ||||
| >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) | >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) | ||||
| >>> shape1 = x.shape()[1:] | |||||
| >>> shape1 = x.shape[1:] | |||||
| >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) | >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) | ||||
| >>> m(x) | >>> m(x) | ||||
| """ | """ | ||||
| @@ -267,21 +267,21 @@ class MobileNetV2(nn.Cell): | |||||
| if isinstance(m, (nn.Conv2d, DepthwiseConv)): | if isinstance(m, (nn.Conv2d, DepthwiseConv)): | ||||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | ||||
| m.weight.data.shape()).astype("float32"))) | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data( | m.bias.set_parameter_data( | ||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.BatchNorm2d): | elif isinstance(m, nn.BatchNorm2d): | ||||
| m.gamma.set_parameter_data( | m.gamma.set_parameter_data( | ||||
| Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) | |||||
| Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) | |||||
| m.beta.set_parameter_data( | m.beta.set_parameter_data( | ||||
| Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.Dense): | elif isinstance(m, nn.Dense): | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal( | m.weight.set_parameter_data(Tensor(np.random.normal( | ||||
| 0, 0.01, m.weight.data.shape()).astype("float32"))) | |||||
| 0, 0.01, m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data( | m.bias.set_parameter_data( | ||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| def mobilenet_v2(**kwargs): | def mobilenet_v2(**kwargs): | ||||
| @@ -322,21 +322,21 @@ class MobileNetV3(nn.Cell): | |||||
| if isinstance(m, (nn.Conv2d)): | if isinstance(m, (nn.Conv2d)): | ||||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | ||||
| m.weight.data.shape()).astype("float32"))) | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data( | m.bias.set_parameter_data( | ||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.BatchNorm2d): | elif isinstance(m, nn.BatchNorm2d): | ||||
| m.gamma.set_parameter_data( | m.gamma.set_parameter_data( | ||||
| Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) | |||||
| Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) | |||||
| m.beta.set_parameter_data( | m.beta.set_parameter_data( | ||||
| Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.Dense): | elif isinstance(m, nn.Dense): | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal( | m.weight.set_parameter_data(Tensor(np.random.normal( | ||||
| 0, 0.01, m.weight.data.shape()).astype("float32"))) | |||||
| 0, 0.01, m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data( | m.bias.set_parameter_data( | ||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| def mobilenet_v3(model_name, **kwargs): | def mobilenet_v3(model_name, **kwargs): | ||||
| @@ -66,12 +66,12 @@ if __name__ == '__main__': | |||||
| for _, cell in net.cells_and_names(): | for _, cell in net.cells_and_names(): | ||||
| if isinstance(cell, nn.Conv2d): | if isinstance(cell, nn.Conv2d): | ||||
| cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), | cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), | ||||
| cell.weight.default_input.shape(), | |||||
| cell.weight.default_input.dtype()).to_tensor() | |||||
| cell.weight.default_input.shape, | |||||
| cell.weight.default_input.dtype).to_tensor() | |||||
| if isinstance(cell, nn.Dense): | if isinstance(cell, nn.Dense): | ||||
| cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), | cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), | ||||
| cell.weight.default_input.shape(), | |||||
| cell.weight.default_input.dtype()).to_tensor() | |||||
| cell.weight.default_input.shape, | |||||
| cell.weight.default_input.dtype).to_tensor() | |||||
| if not config.label_smooth: | if not config.label_smooth: | ||||
| config.label_smooth_factor = 0.0 | config.label_smooth_factor = 0.0 | ||||
| loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) | loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) | ||||
| @@ -23,9 +23,9 @@ def init_net_param(network, initialize_mode='TruncatedNormal'): | |||||
| for p in params: | for p in params: | ||||
| if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | ||||
| if initialize_mode == 'TruncatedNormal': | if initialize_mode == 'TruncatedNormal': | ||||
| p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape(), p.data.dtype())) | |||||
| p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape, p.data.dtype)) | |||||
| else: | else: | ||||
| p.set_parameter_data(initialize_mode, p.data.shape(), p.data.dtype()) | |||||
| p.set_parameter_data(initialize_mode, p.data.shape, p.data.dtype) | |||||
| def load_backbone_params(network, param_dict): | def load_backbone_params(network, param_dict): | ||||
| @@ -78,15 +78,15 @@ class GNNFeatureTransform(nn.Cell): | |||||
| self.has_bias = check_bool(has_bias) | self.has_bias = check_bool(has_bias) | ||||
| if isinstance(weight_init, Tensor): | if isinstance(weight_init, Tensor): | ||||
| if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ | |||||
| weight_init.shape()[1] != in_channels: | |||||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||||
| weight_init.shape[1] != in_channels: | |||||
| raise ValueError("weight_init shape error") | raise ValueError("weight_init shape error") | ||||
| self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") | self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") | ||||
| if self.has_bias: | if self.has_bias: | ||||
| if isinstance(bias_init, Tensor): | if isinstance(bias_init, Tensor): | ||||
| if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: | |||||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||||
| raise ValueError("bias_init shape error") | raise ValueError("bias_init shape error") | ||||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | ||||
| @@ -51,4 +51,4 @@ def test_AllGather(): | |||||
| diff = output.asnumpy() - expect | diff = output.asnumpy() - expect | ||||
| error = np.ones(shape=expect.shape) * 1.0e-5 | error = np.ones(shape=expect.shape) * 1.0e-5 | ||||
| assert np.all(diff < error) | assert np.all(diff < error) | ||||
| assert output.shape() == expect.shape | |||||
| assert output.shape == expect.shape | |||||
| @@ -62,19 +62,19 @@ def test_AllReduce(): | |||||
| diff0 = output[0].asnumpy() - expect0 | diff0 = output[0].asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output[0].shape() == expect0.shape | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = expect0 | expect1 = expect0 | ||||
| diff1 = output[1].asnumpy() - expect1 | diff1 = output[1].asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output[1].shape() == expect1.shape | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = expect1 | expect2 = expect1 | ||||
| diff2 = output[2].asnumpy() - expect2 | diff2 = output[2].asnumpy() - expect2 | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output[2].shape() == expect2.shape | |||||
| assert output[2].shape == expect2.shape | |||||
| class Net2(nn.Cell): | class Net2(nn.Cell): | ||||
| @@ -108,16 +108,16 @@ def test_AllReduce2(): | |||||
| diff0 = abs(output[0].asnumpy() - expect0) | diff0 = abs(output[0].asnumpy() - expect0) | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output[0].shape() == expect0.shape | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = expect0 * size | expect1 = expect0 * size | ||||
| diff1 = abs(output[1].asnumpy() - expect1) | diff1 = abs(output[1].asnumpy() - expect1) | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output[1].shape() == expect1.shape | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = expect1 * size | expect2 = expect1 * size | ||||
| diff2 = abs(output[2].asnumpy() - expect2) | diff2 = abs(output[2].asnumpy() - expect2) | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output[2].shape() == expect2.shape | |||||
| assert output[2].shape == expect2.shape | |||||
| @@ -61,16 +61,16 @@ def test_ReduceScatter(): | |||||
| diff0 = output[0].asnumpy() - expect0 | diff0 = output[0].asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output[0].shape() == expect0.shape | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * size | expect1 = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * size | ||||
| diff1 = output[1].asnumpy() - expect1 | diff1 = output[1].asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output[1].shape() == expect1.shape | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * 1 | expect2 = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * 1 | ||||
| diff2 = output[2].asnumpy() - expect2 | diff2 = output[2].asnumpy() - expect2 | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output[2].shape() == expect2.shape | |||||
| assert output[2].shape == expect2.shape | |||||
| @@ -73,7 +73,7 @@ class FusedLayerNorm(Cell): | |||||
| Examples: | Examples: | ||||
| >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) | >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) | ||||
| >>> shape1 = x.shape()[1:] | |||||
| >>> shape1 = x.shape[1:] | |||||
| >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) | >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) | ||||
| >>> m(x) | >>> m(x) | ||||
| """ | """ | ||||
| @@ -75,93 +75,93 @@ def test_tensor_auto_cast(): | |||||
| t_fp64 = Tensor(np.ones([2, 1, 2, 2]), mstype.float64) | t_fp64 = Tensor(np.ones([2, 1, 2, 2]), mstype.float64) | ||||
| net = TensorAutoCast() | net = TensorAutoCast() | ||||
| rs = net(t_uint8, t_int8) | rs = net(t_uint8, t_int8) | ||||
| assert rs.dtype() == mstype.int16 | |||||
| assert rs.dtype == mstype.int16 | |||||
| rs = net(t_uint8, t_int16) | rs = net(t_uint8, t_int16) | ||||
| assert rs.dtype() == mstype.int16 | |||||
| assert rs.dtype == mstype.int16 | |||||
| rs = net(t_uint8, t_int32) | rs = net(t_uint8, t_int32) | ||||
| assert rs.dtype() == mstype.int32 | |||||
| assert rs.dtype == mstype.int32 | |||||
| rs = net(t_uint8, t_int64) | rs = net(t_uint8, t_int64) | ||||
| assert rs.dtype() == mstype.int64 | |||||
| assert rs.dtype == mstype.int64 | |||||
| rs = net(t_int8, t_int16) | rs = net(t_int8, t_int16) | ||||
| assert rs.dtype() == mstype.int16 | |||||
| assert rs.dtype == mstype.int16 | |||||
| rs = net(t_int8, t_int32) | rs = net(t_int8, t_int32) | ||||
| assert rs.dtype() == mstype.int32 | |||||
| assert rs.dtype == mstype.int32 | |||||
| rs = net(t_int8, t_int64) | rs = net(t_int8, t_int64) | ||||
| assert rs.dtype() == mstype.int64 | |||||
| assert rs.dtype == mstype.int64 | |||||
| rs = net(t_int16, t_int32) | rs = net(t_int16, t_int32) | ||||
| assert rs.dtype() == mstype.int32 | |||||
| assert rs.dtype == mstype.int32 | |||||
| rs = net(t_int16, t_int64) | rs = net(t_int16, t_int64) | ||||
| assert rs.dtype() == mstype.int64 | |||||
| assert rs.dtype == mstype.int64 | |||||
| rs = net(t_int32, t_int64) | rs = net(t_int32, t_int64) | ||||
| assert rs.dtype() == mstype.int64 | |||||
| assert rs.dtype == mstype.int64 | |||||
| rs = net(t_fp16, t_fp32) | rs = net(t_fp16, t_fp32) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = net(t_fp16, t_fp64) | rs = net(t_fp16, t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| rs = net(t_fp32, t_fp64) | rs = net(t_fp32, t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| rs = net(t_uint8, t_fp16) | rs = net(t_uint8, t_fp16) | ||||
| assert rs.dtype() == mstype.float16 | |||||
| assert rs.dtype == mstype.float16 | |||||
| rs = net(t_uint8, t_fp32) | rs = net(t_uint8, t_fp32) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = net(t_uint8, t_fp64) | rs = net(t_uint8, t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| rs = net(t_int8, t_fp64) | rs = net(t_int8, t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| rs = net(t_int16, t_fp64) | rs = net(t_int16, t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| rs = net(t_int32, t_fp64) | rs = net(t_int32, t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| rs = net(t_int64, t_fp64) | rs = net(t_int64, t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| rs = net(t_fp16, t_int8) | rs = net(t_fp16, t_int8) | ||||
| assert rs.dtype() == mstype.float16 | |||||
| assert rs.dtype == mstype.float16 | |||||
| rs = net(t_fp16, t_uint8) | rs = net(t_fp16, t_uint8) | ||||
| assert rs.dtype() == mstype.float16 | |||||
| assert rs.dtype == mstype.float16 | |||||
| rs = net(t_fp16, t_int16) | rs = net(t_fp16, t_int16) | ||||
| assert rs.dtype() == mstype.float16 | |||||
| assert rs.dtype == mstype.float16 | |||||
| rs = net(t_fp16, t_int32) | rs = net(t_fp16, t_int32) | ||||
| assert rs.dtype() == mstype.float16 | |||||
| assert rs.dtype == mstype.float16 | |||||
| rs = net(t_fp16, t_int64) | rs = net(t_fp16, t_int64) | ||||
| assert rs.dtype() == mstype.float16 | |||||
| assert rs.dtype == mstype.float16 | |||||
| tint = TensorIntAutoCast() | tint = TensorIntAutoCast() | ||||
| rs = tint(t_uint8) | rs = tint(t_uint8) | ||||
| assert rs.dtype() == mstype.uint8 | |||||
| assert rs.dtype == mstype.uint8 | |||||
| rs = tint(t_int8) | rs = tint(t_int8) | ||||
| assert rs.dtype() == mstype.int8 | |||||
| assert rs.dtype == mstype.int8 | |||||
| rs = tint(t_int16) | rs = tint(t_int16) | ||||
| assert rs.dtype() == mstype.int16 | |||||
| assert rs.dtype == mstype.int16 | |||||
| rs = tint(t_int32) | rs = tint(t_int32) | ||||
| assert rs.dtype() == mstype.int32 | |||||
| assert rs.dtype == mstype.int32 | |||||
| rs = tint(t_int64) | rs = tint(t_int64) | ||||
| assert rs.dtype() == mstype.int64 | |||||
| assert rs.dtype == mstype.int64 | |||||
| rs = tint(t_fp16) | rs = tint(t_fp16) | ||||
| assert rs.dtype() == mstype.float16 | |||||
| assert rs.dtype == mstype.float16 | |||||
| rs = tint(t_fp32) | rs = tint(t_fp32) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = tint(t_fp64) | rs = tint(t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| tfp = TensorFPAutoCast() | tfp = TensorFPAutoCast() | ||||
| rs = tfp(t_uint8) | rs = tfp(t_uint8) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = tfp(t_int8) | rs = tfp(t_int8) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = tfp(t_int16) | rs = tfp(t_int16) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = tfp(t_int32) | rs = tfp(t_int32) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = tfp(t_int64) | rs = tfp(t_int64) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = tfp(t_fp16) | rs = tfp(t_fp16) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = tfp(t_fp32) | rs = tfp(t_fp32) | ||||
| assert rs.dtype() == mstype.float32 | |||||
| assert rs.dtype == mstype.float32 | |||||
| rs = tfp(t_fp64) | rs = tfp(t_fp64) | ||||
| assert rs.dtype() == mstype.float64 | |||||
| assert rs.dtype == mstype.float64 | |||||
| t_uint16 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint16) | t_uint16 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint16) | ||||
| t_uint32 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint32) | t_uint32 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint32) | ||||
| @@ -35,7 +35,7 @@ class Net(nn.Cell): | |||||
| self.biasAdd = P.BiasAdd() | self.biasAdd = P.BiasAdd() | ||||
| if isinstance(bias_init, Tensor): | if isinstance(bias_init, Tensor): | ||||
| if bias_init.dim() != 1 or bias_init.shape()[0] != output_channels: | |||||
| if bias_init.dim() != 1 or bias_init.shape[0] != output_channels: | |||||
| raise ValueError("bias_init shape error") | raise ValueError("bias_init shape error") | ||||
| self.bias = Parameter(initializer( | self.bias = Parameter(initializer( | ||||
| @@ -64,7 +64,7 @@ def convert_type(shapes, types): | |||||
| for np_shape, np_type in zip(shapes, types): | for np_shape, np_type in zip(shapes, types): | ||||
| input_np = np.zeros(np_shape, np_type) | input_np = np.zeros(np_shape, np_type) | ||||
| tensor = Tensor(input_np) | tensor = Tensor(input_np) | ||||
| ms_types.append(tensor.dtype()) | |||||
| ms_types.append(tensor.dtype) | |||||
| return ms_types | return ms_types | ||||
| @@ -34,7 +34,7 @@ class NetArgmax(nn.Cell): | |||||
| x = Tensor(np.array([[1., 20., 5.], | x = Tensor(np.array([[1., 20., 5.], | ||||
| [67., 8., 9.], | [67., 8., 9.], | ||||
| [130., 24., 15.]]).astype(np.float32)) | [130., 24., 15.]]).astype(np.float32)) | ||||
| self.x = Parameter(initializer(x, x.shape()), name='x') | |||||
| self.x = Parameter(initializer(x, x.shape), name='x') | |||||
| def construct(self): | def construct(self): | ||||
| return self.argmax(self.x) | return self.argmax(self.x) | ||||
| @@ -32,8 +32,8 @@ class NetEqualCount(nn.Cell): | |||||
| self.equalcount = P.EqualCount() | self.equalcount = P.EqualCount() | ||||
| x = Tensor(np.array([1, 20, 5]).astype(np.int32)) | x = Tensor(np.array([1, 20, 5]).astype(np.int32)) | ||||
| y = Tensor(np.array([2, 20, 5]).astype(np.int32)) | y = Tensor(np.array([2, 20, 5]).astype(np.int32)) | ||||
| self.x = Parameter(initializer(x, x.shape()), name='x') | |||||
| self.y = Parameter(initializer(y, y.shape()), name='y') | |||||
| self.x = Parameter(initializer(x, x.shape), name='x') | |||||
| self.y = Parameter(initializer(y, y.shape), name='y') | |||||
| def construct(self): | def construct(self): | ||||
| return self.equalcount(self.x, self.y) | return self.equalcount(self.x, self.y) | ||||
| @@ -33,7 +33,7 @@ class NetSoftmax(nn.Cell): | |||||
| x = Tensor(np.array([[0.1, 0.3, 0.6], | x = Tensor(np.array([[0.1, 0.3, 0.6], | ||||
| [0.2, -0.6, 0.8], | [0.2, -0.6, 0.8], | ||||
| [0.6, 1, 0.4]]).astype(np.float32)) | [0.6, 1, 0.4]]).astype(np.float32)) | ||||
| self.x = Parameter(initializer(x, x.shape()), name='x') | |||||
| self.x = Parameter(initializer(x, x.shape), name='x') | |||||
| def construct(self): | def construct(self): | ||||
| return self.softmax(self.x) | return self.softmax(self.x) | ||||
| @@ -32,9 +32,9 @@ class NetSoftmaxWithCrossEntropy(nn.Cell): | |||||
| logits = Tensor(np.array([[1, 1, 10], | logits = Tensor(np.array([[1, 1, 10], | ||||
| [1, 10, 1], | [1, 10, 1], | ||||
| [10, 1, 1]]).astype(np.float32)) | [10, 1, 1]]).astype(np.float32)) | ||||
| self.logits = Parameter(initializer(logits, logits.shape()), name='logits') | |||||
| self.logits = Parameter(initializer(logits, logits.shape), name='logits') | |||||
| labels = Tensor(np.array([2, 1, 0]).astype(np.int32)) | labels = Tensor(np.array([2, 1, 0]).astype(np.int32)) | ||||
| self.labels = Parameter(initializer(labels, labels.shape()), name='labels') | |||||
| self.labels = Parameter(initializer(labels, labels.shape), name='labels') | |||||
| self.SoftmaxWithCrossEntropy = P.SparseSoftmaxCrossEntropyWithLogits(True) | self.SoftmaxWithCrossEntropy = P.SparseSoftmaxCrossEntropyWithLogits(True) | ||||
| def construct(self): | def construct(self): | ||||
| @@ -50,4 +50,4 @@ def test_correction_mul(): | |||||
| diff = output.asnumpy() - expect | diff = output.asnumpy() - expect | ||||
| assert np.all(diff < error) | assert np.all(diff < error) | ||||
| assert np.all(diff > error * -1) | assert np.all(diff > error * -1) | ||||
| assert output.shape() == expect.shape | |||||
| assert output.shape == expect.shape | |||||
| @@ -65,19 +65,19 @@ def test_equal(): | |||||
| equal = NetEqual() | equal = NetEqual() | ||||
| output0 = equal(x0, y0) | output0 = equal(x0, y0) | ||||
| assert np.all(output0.asnumpy() == expect0) | assert np.all(output0.asnumpy() == expect0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = equal(x1, y1) | output1 = equal(x1, y1) | ||||
| assert np.all(output1.asnumpy() == expect1) | assert np.all(output1.asnumpy() == expect1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| equal = NetEqual() | equal = NetEqual() | ||||
| output0 = equal(x0, y0) | output0 = equal(x0, y0) | ||||
| assert np.all(output0.asnumpy() == expect0) | assert np.all(output0.asnumpy() == expect0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = equal(x1, y1) | output1 = equal(x1, y1) | ||||
| assert np.all(output1.asnumpy() == expect1) | assert np.all(output1.asnumpy() == expect1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -92,13 +92,13 @@ def test_notequal(): | |||||
| notequal = NetNotEqual() | notequal = NetNotEqual() | ||||
| output0 = notequal(x0, y0) | output0 = notequal(x0, y0) | ||||
| assert np.all(output0.asnumpy() == expect0) | assert np.all(output0.asnumpy() == expect0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| notequal = NetNotEqual() | notequal = NetNotEqual() | ||||
| output0 = notequal(x0, y0) | output0 = notequal(x0, y0) | ||||
| assert np.all(output0.asnumpy() == expect0) | assert np.all(output0.asnumpy() == expect0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -113,10 +113,10 @@ def test_greaterqual(): | |||||
| gequal = NetGreaterEqual() | gequal = NetGreaterEqual() | ||||
| output0 = gequal(x0, y0) | output0 = gequal(x0, y0) | ||||
| assert np.all(output0.asnumpy() == expect0) | assert np.all(output0.asnumpy() == expect0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| gequal = NetGreaterEqual() | gequal = NetGreaterEqual() | ||||
| output0 = gequal(x0, y0) | output0 = gequal(x0, y0) | ||||
| assert np.all(output0.asnumpy() == expect0) | assert np.all(output0.asnumpy() == expect0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| @@ -49,19 +49,19 @@ def test_exp(): | |||||
| output0 = exp(x0) | output0 = exp(x0) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = exp(x1) | output1 = exp(x1) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | ||||
| exp = NetExp() | exp = NetExp() | ||||
| output0 = exp(x0) | output0 = exp(x0) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = exp(x1) | output1 = exp(x1) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| @@ -50,10 +50,10 @@ def test_log(): | |||||
| output1 = log(x1) | output1 = log(x1) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| log = NetLog() | log = NetLog() | ||||
| @@ -61,7 +61,7 @@ def test_log(): | |||||
| output1 = log(x1) | output1 = log(x1) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| @@ -64,35 +64,35 @@ def test_mul(): | |||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = mul(x1, y1) | output1 = mul(x1, y1) | ||||
| expect1 = np.multiply(x1_np, y1_np) | expect1 = np.multiply(x1_np, y1_np) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| output2 = mul(x2, y2) | output2 = mul(x2, y2) | ||||
| expect2 = np.multiply(x2_np, y2_np) | expect2 = np.multiply(x2_np, y2_np) | ||||
| diff2 = output2.asnumpy() - expect2 | diff2 = output2.asnumpy() - expect2 | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output2.shape() == expect2.shape | |||||
| assert output2.shape == expect2.shape | |||||
| output3 = mul(x3, y3) | output3 = mul(x3, y3) | ||||
| expect3 = np.multiply(x3_np, y3_np) | expect3 = np.multiply(x3_np, y3_np) | ||||
| diff3 = output3.asnumpy() - expect3 | diff3 = output3.asnumpy() - expect3 | ||||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | error3 = np.ones(shape=expect3.shape) * 1.0e-5 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output3.shape() == expect3.shape | |||||
| assert output3.shape == expect3.shape | |||||
| output4 = mul(x4, y4) | output4 = mul(x4, y4) | ||||
| expect4 = np.multiply(x4_np, y4_np) | expect4 = np.multiply(x4_np, y4_np) | ||||
| diff4 = output4.asnumpy() - expect4 | diff4 = output4.asnumpy() - expect4 | ||||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | error4 = np.ones(shape=expect4.shape) * 1.0e-5 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output4.shape() == expect4.shape | |||||
| assert output4.shape == expect4.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| mul = NetMul() | mul = NetMul() | ||||
| @@ -101,32 +101,32 @@ def test_mul(): | |||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = mul(x1, y1) | output1 = mul(x1, y1) | ||||
| expect1 = np.multiply(x1_np, y1_np) | expect1 = np.multiply(x1_np, y1_np) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| output2 = mul(x2, y2) | output2 = mul(x2, y2) | ||||
| expect2 = np.multiply(x2_np, y2_np) | expect2 = np.multiply(x2_np, y2_np) | ||||
| diff2 = output2.asnumpy() - expect2 | diff2 = output2.asnumpy() - expect2 | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output2.shape() == expect2.shape | |||||
| assert output2.shape == expect2.shape | |||||
| output3 = mul(x3, y3) | output3 = mul(x3, y3) | ||||
| expect3 = np.multiply(x3_np, y3_np) | expect3 = np.multiply(x3_np, y3_np) | ||||
| diff3 = output3.asnumpy() - expect3 | diff3 = output3.asnumpy() - expect3 | ||||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | error3 = np.ones(shape=expect3.shape) * 1.0e-5 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output3.shape() == expect3.shape | |||||
| assert output3.shape == expect3.shape | |||||
| output4 = mul(x4, y4) | output4 = mul(x4, y4) | ||||
| expect4 = np.multiply(x4_np, y4_np) | expect4 = np.multiply(x4_np, y4_np) | ||||
| diff4 = output4.asnumpy() - expect4 | diff4 = output4.asnumpy() - expect4 | ||||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | error4 = np.ones(shape=expect4.shape) * 1.0e-5 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output4.shape() == expect4.shape | |||||
| assert output4.shape == expect4.shape | |||||
| @@ -49,19 +49,19 @@ def test_neg(): | |||||
| output0 = neg(x0) | output0 = neg(x0) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = neg(x1) | output1 = neg(x1) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| neg = NetNeg() | neg = NetNeg() | ||||
| output0 = neg(x0) | output0 = neg(x0) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = neg(x1) | output1 = neg(x1) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| @@ -64,35 +64,35 @@ def test_real_div(): | |||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = real_div(x1, y1) | output1 = real_div(x1, y1) | ||||
| expect1 = np.divide(x1_np, y1_np) | expect1 = np.divide(x1_np, y1_np) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| output2 = real_div(x2, y2) | output2 = real_div(x2, y2) | ||||
| expect2 = np.divide(x2_np, y2_np) | expect2 = np.divide(x2_np, y2_np) | ||||
| diff2 = output2.asnumpy() - expect2 | diff2 = output2.asnumpy() - expect2 | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output2.shape() == expect2.shape | |||||
| assert output2.shape == expect2.shape | |||||
| output3 = real_div(x3, y3) | output3 = real_div(x3, y3) | ||||
| expect3 = np.divide(x3_np, y3_np) | expect3 = np.divide(x3_np, y3_np) | ||||
| diff3 = output3.asnumpy() - expect3 | diff3 = output3.asnumpy() - expect3 | ||||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | error3 = np.ones(shape=expect3.shape) * 1.0e-5 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output3.shape() == expect3.shape | |||||
| assert output3.shape == expect3.shape | |||||
| output4 = real_div(x4, y4) | output4 = real_div(x4, y4) | ||||
| expect4 = np.divide(x4_np, y4_np) | expect4 = np.divide(x4_np, y4_np) | ||||
| diff4 = output4.asnumpy() - expect4 | diff4 = output4.asnumpy() - expect4 | ||||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | error4 = np.ones(shape=expect4.shape) * 1.0e-5 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output4.shape() == expect4.shape | |||||
| assert output4.shape == expect4.shape | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | ||||
| real_div = NetRealDiv() | real_div = NetRealDiv() | ||||
| @@ -101,32 +101,32 @@ def test_real_div(): | |||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = real_div(x1, y1) | output1 = real_div(x1, y1) | ||||
| expect1 = np.divide(x1_np, y1_np) | expect1 = np.divide(x1_np, y1_np) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| output2 = real_div(x2, y2) | output2 = real_div(x2, y2) | ||||
| expect2 = np.divide(x2_np, y2_np) | expect2 = np.divide(x2_np, y2_np) | ||||
| diff2 = output2.asnumpy() - expect2 | diff2 = output2.asnumpy() - expect2 | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output2.shape() == expect2.shape | |||||
| assert output2.shape == expect2.shape | |||||
| output3 = real_div(x3, y3) | output3 = real_div(x3, y3) | ||||
| expect3 = np.divide(x3_np, y3_np) | expect3 = np.divide(x3_np, y3_np) | ||||
| diff3 = output3.asnumpy() - expect3 | diff3 = output3.asnumpy() - expect3 | ||||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | error3 = np.ones(shape=expect3.shape) * 1.0e-5 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output3.shape() == expect3.shape | |||||
| assert output3.shape == expect3.shape | |||||
| output4 = real_div(x4, y4) | output4 = real_div(x4, y4) | ||||
| expect4 = np.divide(x4_np, y4_np) | expect4 = np.divide(x4_np, y4_np) | ||||
| diff4 = output4.asnumpy() - expect4 | diff4 = output4.asnumpy() - expect4 | ||||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | error4 = np.ones(shape=expect4.shape) * 1.0e-5 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output4.shape() == expect4.shape | |||||
| assert output4.shape == expect4.shape | |||||
| @@ -49,19 +49,19 @@ def test_Reciprocal(): | |||||
| output0 = reciprocal(x0) | output0 = reciprocal(x0) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = reciprocal(x1) | output1 = reciprocal(x1) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| reciprocal = NetReciprocal() | reciprocal = NetReciprocal() | ||||
| output0 = reciprocal(x0) | output0 = reciprocal(x0) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = reciprocal(x1) | output1 = reciprocal(x1) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| @@ -128,43 +128,43 @@ def test_ReduceMax(): | |||||
| diff0 = abs(output[0].asnumpy() - expect0) | diff0 = abs(output[0].asnumpy() - expect0) | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output[0].shape() == expect0.shape | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = np.max(x1, axis=axis1, keepdims=keep_dims1) | expect1 = np.max(x1, axis=axis1, keepdims=keep_dims1) | ||||
| diff1 = abs(output[1].asnumpy() - expect1) | diff1 = abs(output[1].asnumpy() - expect1) | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output[1].shape() == expect1.shape | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = np.max(x2, axis=axis2, keepdims=keep_dims2) | expect2 = np.max(x2, axis=axis2, keepdims=keep_dims2) | ||||
| diff2 = abs(output[2].asnumpy() - expect2) | diff2 = abs(output[2].asnumpy() - expect2) | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output[2].shape() == expect2.shape | |||||
| assert output[2].shape == expect2.shape | |||||
| expect3 = np.max(x3, axis=axis3, keepdims=keep_dims3) | expect3 = np.max(x3, axis=axis3, keepdims=keep_dims3) | ||||
| diff3 = abs(output[3].asnumpy() - expect3) | diff3 = abs(output[3].asnumpy() - expect3) | ||||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | error3 = np.ones(shape=expect3.shape) * 1.0e-5 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output[3].shape() == expect3.shape | |||||
| assert output[3].shape == expect3.shape | |||||
| expect4 = np.max(x4, axis=np_axis4, keepdims=keep_dims4) | expect4 = np.max(x4, axis=np_axis4, keepdims=keep_dims4) | ||||
| diff4 = abs(output[4].asnumpy() - expect4) | diff4 = abs(output[4].asnumpy() - expect4) | ||||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | error4 = np.ones(shape=expect4.shape) * 1.0e-5 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output[4].shape() == expect4.shape | |||||
| assert output[4].shape == expect4.shape | |||||
| expect5 = np.max(x5, axis=np_axis5, keepdims=keep_dims5) | expect5 = np.max(x5, axis=np_axis5, keepdims=keep_dims5) | ||||
| diff5 = abs(output[5].asnumpy() - expect5) | diff5 = abs(output[5].asnumpy() - expect5) | ||||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | error5 = np.ones(shape=expect5.shape) * 1.0e-5 | ||||
| assert np.all(diff5 < error5) | assert np.all(diff5 < error5) | ||||
| assert output[5].shape() == expect5.shape | |||||
| assert output[5].shape == expect5.shape | |||||
| expect6 = np.max(x6, axis=axis6, keepdims=keep_dims6) | expect6 = np.max(x6, axis=axis6, keepdims=keep_dims6) | ||||
| diff6 = abs(output[6].asnumpy() - expect6) | diff6 = abs(output[6].asnumpy() - expect6) | ||||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | error6 = np.ones(shape=expect6.shape) * 1.0e-5 | ||||
| assert np.all(diff6 < error6) | assert np.all(diff6 < error6) | ||||
| assert output[6].shape() == expect6.shape | |||||
| assert output[6].shape == expect6.shape | |||||
| expect7 = np.max(x7, axis=axis7, keepdims=keep_dims7) | expect7 = np.max(x7, axis=axis7, keepdims=keep_dims7) | ||||
| diff7 = abs(output[7].asnumpy() - expect7) | diff7 = abs(output[7].asnumpy() - expect7) | ||||
| @@ -180,88 +180,88 @@ def test_ReduceMean(): | |||||
| diff0 = abs(output[0].asnumpy() - expect0) | diff0 = abs(output[0].asnumpy() - expect0) | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output[0].shape() == expect0.shape | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = np.mean(x1, axis=axis1, keepdims=keep_dims1) | expect1 = np.mean(x1, axis=axis1, keepdims=keep_dims1) | ||||
| diff1 = abs(output[1].asnumpy() - expect1) | diff1 = abs(output[1].asnumpy() - expect1) | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output[1].shape() == expect1.shape | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = np.mean(x2, axis=axis2, keepdims=keep_dims2) | expect2 = np.mean(x2, axis=axis2, keepdims=keep_dims2) | ||||
| diff2 = abs(output[2].asnumpy() - expect2) | diff2 = abs(output[2].asnumpy() - expect2) | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output[2].shape() == expect2.shape | |||||
| assert output[2].shape == expect2.shape | |||||
| expect3 = np.mean(x3, axis=axis3, keepdims=keep_dims3) | expect3 = np.mean(x3, axis=axis3, keepdims=keep_dims3) | ||||
| diff3 = abs(output[3].asnumpy() - expect3) | diff3 = abs(output[3].asnumpy() - expect3) | ||||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | error3 = np.ones(shape=expect3.shape) * 1.0e-5 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output[3].shape() == expect3.shape | |||||
| assert output[3].shape == expect3.shape | |||||
| expect4 = np.mean(x4, axis=axis4, keepdims=keep_dims4) | expect4 = np.mean(x4, axis=axis4, keepdims=keep_dims4) | ||||
| diff4 = abs(output[4].asnumpy() - expect4) | diff4 = abs(output[4].asnumpy() - expect4) | ||||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | error4 = np.ones(shape=expect4.shape) * 1.0e-5 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output[4].shape() == expect4.shape | |||||
| assert output[4].shape == expect4.shape | |||||
| expect5 = np.mean(x5, axis=axis5, keepdims=keep_dims5) | expect5 = np.mean(x5, axis=axis5, keepdims=keep_dims5) | ||||
| diff5 = abs(output[5].asnumpy() - expect5) | diff5 = abs(output[5].asnumpy() - expect5) | ||||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | error5 = np.ones(shape=expect5.shape) * 1.0e-5 | ||||
| assert np.all(diff5 < error5) | assert np.all(diff5 < error5) | ||||
| assert output[5].shape() == expect5.shape | |||||
| assert output[5].shape == expect5.shape | |||||
| expect6 = np.mean(x6, axis=axis6, keepdims=keep_dims6) | expect6 = np.mean(x6, axis=axis6, keepdims=keep_dims6) | ||||
| diff6 = abs(output[6].asnumpy() - expect6) | diff6 = abs(output[6].asnumpy() - expect6) | ||||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | error6 = np.ones(shape=expect6.shape) * 1.0e-5 | ||||
| assert np.all(diff6 < error6) | assert np.all(diff6 < error6) | ||||
| assert output[6].shape() == expect6.shape | |||||
| assert output[6].shape == expect6.shape | |||||
| expect7 = np.mean(x7, axis=axis7, keepdims=keep_dims7) | expect7 = np.mean(x7, axis=axis7, keepdims=keep_dims7) | ||||
| diff7 = abs(output[7].asnumpy() - expect7) | diff7 = abs(output[7].asnumpy() - expect7) | ||||
| error7 = np.ones(shape=expect7.shape) * 1.0e-5 | error7 = np.ones(shape=expect7.shape) * 1.0e-5 | ||||
| assert np.all(diff7 < error7) | assert np.all(diff7 < error7) | ||||
| assert output[7].shape() == expect7.shape | |||||
| assert output[7].shape == expect7.shape | |||||
| expect8 = np.mean(x8, axis=axis8, keepdims=keep_dims8) | expect8 = np.mean(x8, axis=axis8, keepdims=keep_dims8) | ||||
| diff8 = abs(output[8].asnumpy() - expect8) | diff8 = abs(output[8].asnumpy() - expect8) | ||||
| error8 = np.ones(shape=expect8.shape) * 1.0e-5 | error8 = np.ones(shape=expect8.shape) * 1.0e-5 | ||||
| assert np.all(diff8 < error8) | assert np.all(diff8 < error8) | ||||
| assert output[8].shape() == expect8.shape | |||||
| assert output[8].shape == expect8.shape | |||||
| expect9 = np.mean(x9, axis=axis9, keepdims=keep_dims9) | expect9 = np.mean(x9, axis=axis9, keepdims=keep_dims9) | ||||
| diff9 = abs(output[9].asnumpy() - expect9) | diff9 = abs(output[9].asnumpy() - expect9) | ||||
| error9 = np.ones(shape=expect9.shape) * 1.0e-5 | error9 = np.ones(shape=expect9.shape) * 1.0e-5 | ||||
| assert np.all(diff9 < error9) | assert np.all(diff9 < error9) | ||||
| assert output[9].shape() == expect9.shape | |||||
| assert output[9].shape == expect9.shape | |||||
| expect10 = np.mean(x10, axis=axis10, keepdims=keep_dims10) | expect10 = np.mean(x10, axis=axis10, keepdims=keep_dims10) | ||||
| diff10 = abs(output[10].asnumpy() - expect10) | diff10 = abs(output[10].asnumpy() - expect10) | ||||
| error10 = np.ones(shape=expect10.shape) * 1.0e-5 | error10 = np.ones(shape=expect10.shape) * 1.0e-5 | ||||
| assert np.all(diff10 < error10) | assert np.all(diff10 < error10) | ||||
| assert output[10].shape() == expect10.shape | |||||
| assert output[10].shape == expect10.shape | |||||
| expect11 = np.mean(x11, axis=axis11, keepdims=keep_dims11) | expect11 = np.mean(x11, axis=axis11, keepdims=keep_dims11) | ||||
| diff11 = abs(output[11].asnumpy() - expect11) | diff11 = abs(output[11].asnumpy() - expect11) | ||||
| error11 = np.ones(shape=expect11.shape) * 1.0e-5 | error11 = np.ones(shape=expect11.shape) * 1.0e-5 | ||||
| assert np.all(diff11 < error11) | assert np.all(diff11 < error11) | ||||
| assert output[11].shape() == expect11.shape | |||||
| assert output[11].shape == expect11.shape | |||||
| expect12 = np.mean(x12, axis=axis12, keepdims=keep_dims12) | expect12 = np.mean(x12, axis=axis12, keepdims=keep_dims12) | ||||
| diff12 = abs(output[12].asnumpy() - expect12) | diff12 = abs(output[12].asnumpy() - expect12) | ||||
| error12 = np.ones(shape=expect12.shape) * 1.0e-5 | error12 = np.ones(shape=expect12.shape) * 1.0e-5 | ||||
| assert np.all(diff12 < error12) | assert np.all(diff12 < error12) | ||||
| assert output[12].shape() == expect12.shape | |||||
| assert output[12].shape == expect12.shape | |||||
| expect13 = np.mean(x13, axis=axis13, keepdims=keep_dims13) | expect13 = np.mean(x13, axis=axis13, keepdims=keep_dims13) | ||||
| diff13 = abs(output[13].asnumpy() - expect13) | diff13 = abs(output[13].asnumpy() - expect13) | ||||
| error13 = np.ones(shape=expect13.shape) * 1.0e-5 | error13 = np.ones(shape=expect13.shape) * 1.0e-5 | ||||
| assert np.all(diff13 < error13) | assert np.all(diff13 < error13) | ||||
| assert output[13].shape() == expect13.shape | |||||
| assert output[13].shape == expect13.shape | |||||
| expect14 = np.mean(x14, axis=np_axis14, keepdims=keep_dims14) | expect14 = np.mean(x14, axis=np_axis14, keepdims=keep_dims14) | ||||
| diff14 = abs(output[14].asnumpy() - expect14) | diff14 = abs(output[14].asnumpy() - expect14) | ||||
| error14 = np.ones(shape=expect14.shape) * 1.0e-5 | error14 = np.ones(shape=expect14.shape) * 1.0e-5 | ||||
| assert np.all(diff14 < error14) | assert np.all(diff14 < error14) | ||||
| assert output[14].shape() == expect14.shape | |||||
| assert output[14].shape == expect14.shape | |||||
| @@ -182,88 +182,88 @@ def test_ReduceSum(): | |||||
| diff0 = abs(output[0].asnumpy() - expect0) | diff0 = abs(output[0].asnumpy() - expect0) | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output[0].shape() == expect0.shape | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = np.sum(x1, axis=axis1, keepdims=keep_dims1) | expect1 = np.sum(x1, axis=axis1, keepdims=keep_dims1) | ||||
| diff1 = abs(output[1].asnumpy() - expect1) | diff1 = abs(output[1].asnumpy() - expect1) | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output[1].shape() == expect1.shape | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = np.sum(x2, axis=axis2, keepdims=keep_dims2) | expect2 = np.sum(x2, axis=axis2, keepdims=keep_dims2) | ||||
| diff2 = abs(output[2].asnumpy() - expect2) | diff2 = abs(output[2].asnumpy() - expect2) | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output[2].shape() == expect2.shape | |||||
| assert output[2].shape == expect2.shape | |||||
| expect3 = np.sum(x3, axis=axis3, keepdims=keep_dims3) | expect3 = np.sum(x3, axis=axis3, keepdims=keep_dims3) | ||||
| diff3 = abs(output[3].asnumpy() - expect3) | diff3 = abs(output[3].asnumpy() - expect3) | ||||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | error3 = np.ones(shape=expect3.shape) * 1.0e-5 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output[3].shape() == expect3.shape | |||||
| assert output[3].shape == expect3.shape | |||||
| expect4 = np.sum(x4, axis=np_axis4, keepdims=keep_dims4) | expect4 = np.sum(x4, axis=np_axis4, keepdims=keep_dims4) | ||||
| diff4 = abs(output[4].asnumpy() - expect4) | diff4 = abs(output[4].asnumpy() - expect4) | ||||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | error4 = np.ones(shape=expect4.shape) * 1.0e-5 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output[4].shape() == expect4.shape | |||||
| assert output[4].shape == expect4.shape | |||||
| expect5 = np.sum(x5, axis=np_axis5, keepdims=keep_dims5) | expect5 = np.sum(x5, axis=np_axis5, keepdims=keep_dims5) | ||||
| diff5 = abs(output[5].asnumpy() - expect5) | diff5 = abs(output[5].asnumpy() - expect5) | ||||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | error5 = np.ones(shape=expect5.shape) * 1.0e-5 | ||||
| assert np.all(diff5 < error5) | assert np.all(diff5 < error5) | ||||
| assert output[5].shape() == expect5.shape | |||||
| assert output[5].shape == expect5.shape | |||||
| expect6 = np.sum(x6, axis=axis6, keepdims=keep_dims6) | expect6 = np.sum(x6, axis=axis6, keepdims=keep_dims6) | ||||
| diff6 = abs(output[6].asnumpy() - expect6) | diff6 = abs(output[6].asnumpy() - expect6) | ||||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | error6 = np.ones(shape=expect6.shape) * 1.0e-5 | ||||
| assert np.all(diff6 < error6) | assert np.all(diff6 < error6) | ||||
| assert output[6].shape() == expect6.shape | |||||
| assert output[6].shape == expect6.shape | |||||
| expect7 = np.sum(x7, axis=axis7, keepdims=keep_dims7) | expect7 = np.sum(x7, axis=axis7, keepdims=keep_dims7) | ||||
| diff7 = abs(output[7].asnumpy() - expect7) | diff7 = abs(output[7].asnumpy() - expect7) | ||||
| error7 = np.ones(shape=expect7.shape) * 1.0e-5 | error7 = np.ones(shape=expect7.shape) * 1.0e-5 | ||||
| assert np.all(diff7 < error7) | assert np.all(diff7 < error7) | ||||
| assert output[7].shape() == expect7.shape | |||||
| assert output[7].shape == expect7.shape | |||||
| expect8 = np.sum(x8, axis=axis8, keepdims=keep_dims8) | expect8 = np.sum(x8, axis=axis8, keepdims=keep_dims8) | ||||
| diff8 = abs(output[8].asnumpy() - expect8) | diff8 = abs(output[8].asnumpy() - expect8) | ||||
| error8 = np.ones(shape=expect8.shape) * 1.0e-5 | error8 = np.ones(shape=expect8.shape) * 1.0e-5 | ||||
| assert np.all(diff8 < error8) | assert np.all(diff8 < error8) | ||||
| assert output[8].shape() == expect8.shape | |||||
| assert output[8].shape == expect8.shape | |||||
| expect9 = np.sum(x9, axis=axis9, keepdims=keep_dims9) | expect9 = np.sum(x9, axis=axis9, keepdims=keep_dims9) | ||||
| diff9 = abs(output[9].asnumpy() - expect9) | diff9 = abs(output[9].asnumpy() - expect9) | ||||
| error9 = np.ones(shape=expect9.shape) * 1.0e-5 | error9 = np.ones(shape=expect9.shape) * 1.0e-5 | ||||
| assert np.all(diff9 < error9) | assert np.all(diff9 < error9) | ||||
| assert output[9].shape() == expect9.shape | |||||
| assert output[9].shape == expect9.shape | |||||
| expect10 = np.sum(x10, axis=axis10, keepdims=keep_dims10) | expect10 = np.sum(x10, axis=axis10, keepdims=keep_dims10) | ||||
| diff10 = abs(output[10].asnumpy() - expect10) | diff10 = abs(output[10].asnumpy() - expect10) | ||||
| error10 = np.ones(shape=expect10.shape) * 1.0e-5 | error10 = np.ones(shape=expect10.shape) * 1.0e-5 | ||||
| assert np.all(diff10 < error10) | assert np.all(diff10 < error10) | ||||
| assert output[10].shape() == expect10.shape | |||||
| assert output[10].shape == expect10.shape | |||||
| expect11 = np.sum(x11, axis=axis11, keepdims=keep_dims11) | expect11 = np.sum(x11, axis=axis11, keepdims=keep_dims11) | ||||
| diff11 = abs(output[11].asnumpy() - expect11) | diff11 = abs(output[11].asnumpy() - expect11) | ||||
| error11 = np.ones(shape=expect11.shape) * 1.0e-5 | error11 = np.ones(shape=expect11.shape) * 1.0e-5 | ||||
| assert np.all(diff11 < error11) | assert np.all(diff11 < error11) | ||||
| assert output[11].shape() == expect11.shape | |||||
| assert output[11].shape == expect11.shape | |||||
| expect12 = np.sum(x12, axis=axis12, keepdims=keep_dims12) | expect12 = np.sum(x12, axis=axis12, keepdims=keep_dims12) | ||||
| diff12 = abs(output[12].asnumpy() - expect12) | diff12 = abs(output[12].asnumpy() - expect12) | ||||
| error12 = np.ones(shape=expect12.shape) * 1.0e-5 | error12 = np.ones(shape=expect12.shape) * 1.0e-5 | ||||
| assert np.all(diff12 < error12) | assert np.all(diff12 < error12) | ||||
| assert output[12].shape() == expect12.shape | |||||
| assert output[12].shape == expect12.shape | |||||
| expect13 = np.sum(x13, axis=axis13, keepdims=keep_dims13) | expect13 = np.sum(x13, axis=axis13, keepdims=keep_dims13) | ||||
| diff13 = abs(output[13].asnumpy() - expect13) | diff13 = abs(output[13].asnumpy() - expect13) | ||||
| error13 = np.ones(shape=expect13.shape) * 1.0e-5 | error13 = np.ones(shape=expect13.shape) * 1.0e-5 | ||||
| assert np.all(diff13 < error13) | assert np.all(diff13 < error13) | ||||
| assert output[13].shape() == expect13.shape | |||||
| assert output[13].shape == expect13.shape | |||||
| expect14 = np.sum(x14, axis=np_axis14, keepdims=keep_dims14) | expect14 = np.sum(x14, axis=np_axis14, keepdims=keep_dims14) | ||||
| diff14 = abs(output[14].asnumpy() - expect14) | diff14 = abs(output[14].asnumpy() - expect14) | ||||
| error14 = np.ones(shape=expect14.shape) * 1.0e-5 | error14 = np.ones(shape=expect14.shape) * 1.0e-5 | ||||
| assert np.all(diff14 < error14) | assert np.all(diff14 < error14) | ||||
| assert output[14].shape() == expect14.shape | |||||
| assert output[14].shape == expect14.shape | |||||
| @@ -76,19 +76,19 @@ def test_Sub(): | |||||
| output4 = sub(x4, y4) | output4 = sub(x4, y4) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| diff2 = output2.asnumpy() - expect2 | diff2 = output2.asnumpy() - expect2 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output2.shape() == expect2.shape | |||||
| assert output2.shape == expect2.shape | |||||
| diff3 = output3.asnumpy() - expect3 | diff3 = output3.asnumpy() - expect3 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output3.shape() == expect3.shape | |||||
| assert output3.shape == expect3.shape | |||||
| diff4 = output4.asnumpy() - expect4 | diff4 = output4.asnumpy() - expect4 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output4.shape() == expect4.shape | |||||
| assert output4.shape == expect4.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| sub = Net() | sub = Net() | ||||
| @@ -99,16 +99,16 @@ def test_Sub(): | |||||
| output4 = sub(x4, y4) | output4 = sub(x4, y4) | ||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| diff2 = output2.asnumpy() - expect2 | diff2 = output2.asnumpy() - expect2 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output2.shape() == expect2.shape | |||||
| assert output2.shape == expect2.shape | |||||
| diff3 = output3.asnumpy() - expect3 | diff3 = output3.asnumpy() - expect3 | ||||
| assert np.all(diff3 < error3) | assert np.all(diff3 < error3) | ||||
| assert output3.shape() == expect3.shape | |||||
| assert output3.shape == expect3.shape | |||||
| diff4 = output4.asnumpy() - expect4 | diff4 = output4.asnumpy() - expect4 | ||||
| assert np.all(diff4 < error4) | assert np.all(diff4 < error4) | ||||
| assert output4.shape() == expect4.shape | |||||
| assert output4.shape == expect4.shape | |||||
| @@ -65,16 +65,16 @@ def test_tile(): | |||||
| diff0 = output[0].asnumpy() - expect0 | diff0 = output[0].asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output[0].shape() == expect0.shape | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = np.tile(input_x1, mul1) | expect1 = np.tile(input_x1, mul1) | ||||
| diff1 = output[1].asnumpy() - expect1 | diff1 = output[1].asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output[1].shape() == expect1.shape | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = np.tile(input_x2, mul2) | expect2 = np.tile(input_x2, mul2) | ||||
| diff2 = output[2].asnumpy() - expect2 | diff2 = output[2].asnumpy() - expect2 | ||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | error2 = np.ones(shape=expect2.shape) * 1.0e-5 | ||||
| assert np.all(diff2 < error2) | assert np.all(diff2 < error2) | ||||
| assert output[2].shape() == expect2.shape | |||||
| assert output[2].shape == expect2.shape | |||||
| @@ -50,14 +50,14 @@ def test_ZerosLike(): | |||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = zeros_like(x1) | output1 = zeros_like(x1) | ||||
| expect1 = np.zeros_like(x1_np) | expect1 = np.zeros_like(x1_np) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| zeros_like = NetZerosLike() | zeros_like = NetZerosLike() | ||||
| @@ -66,11 +66,11 @@ def test_ZerosLike(): | |||||
| diff0 = output0.asnumpy() - expect0 | diff0 = output0.asnumpy() - expect0 | ||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | error0 = np.ones(shape=expect0.shape) * 1.0e-5 | ||||
| assert np.all(diff0 < error0) | assert np.all(diff0 < error0) | ||||
| assert output0.shape() == expect0.shape | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = zeros_like(x1) | output1 = zeros_like(x1) | ||||
| expect1 = np.zeros_like(x1_np) | expect1 = np.zeros_like(x1_np) | ||||
| diff1 = output1.asnumpy() - expect1 | diff1 = output1.asnumpy() - expect1 | ||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | error1 = np.ones(shape=expect1.shape) * 1.0e-5 | ||||
| assert np.all(diff1 < error1) | assert np.all(diff1 < error1) | ||||
| assert output1.shape() == expect1.shape | |||||
| assert output1.shape == expect1.shape | |||||
| @@ -20,6 +20,7 @@ import mindspore.nn as nn | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common import dtype as mstype | |||||
| from tests.ut.python.ut_filter import non_graph_engine | from tests.ut.python.ut_filter import non_graph_engine | ||||
| from tests.mindspore_test_framework.mindspore_test import mindspore_test | from tests.mindspore_test_framework.mindspore_test import mindspore_test | ||||
| from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | ||||
| @@ -44,7 +45,12 @@ def test_list_equal(): | |||||
| y = Tensor(np.zeros([3, 4, 5], np.int32)) | y = Tensor(np.zeros([3, 4, 5], np.int32)) | ||||
| z = [1, 2, 3] | z = [1, 2, 3] | ||||
| net = Net(z) | net = Net(z) | ||||
| assert net(x, y) == x | |||||
| ret = net(x, y) | |||||
| print(ret.asnumpy()) | |||||
| assert ret == x | |||||
| assert ret.dtype == mstype.int32 | |||||
| assert ret.shape == (6, 8, 10) | |||||
| def test_list_not_equal(): | def test_list_not_equal(): | ||||
| @@ -33,7 +33,7 @@ class Net(nn.Cell): | |||||
| self.biasAdd = P.BiasAdd() | self.biasAdd = P.BiasAdd() | ||||
| if isinstance(bias_init, Tensor): | if isinstance(bias_init, Tensor): | ||||
| if bias_init.dim() != 1 or bias_init.shape()[0] != output_channels: | |||||
| if bias_init.dim() != 1 or bias_init.shape[0] != output_channels: | |||||
| raise ValueError("bias_init shape error") | raise ValueError("bias_init shape error") | ||||
| self.bias = Parameter(initializer( | self.bias = Parameter(initializer( | ||||
| @@ -65,7 +65,7 @@ def test_bias_add(test_with_simu): | |||||
| self.biasAdd = P.BiasAdd() | self.biasAdd = P.BiasAdd() | ||||
| if isinstance(bias_init, Tensor): | if isinstance(bias_init, Tensor): | ||||
| if bias_init.dim() != 1 or bias_init.shape()[0] != output_channels: | |||||
| if bias_init.dim() != 1 or bias_init.shape[0] != output_channels: | |||||
| raise ValueError("bias_init shape error") | raise ValueError("bias_init shape error") | ||||
| self.bias = Parameter(initializer( | self.bias = Parameter(initializer( | ||||
| @@ -50,148 +50,148 @@ def test_tensor(): | |||||
| """test_tensor""" | """test_tensor""" | ||||
| t1 = ms.Tensor(ndarr) | t1 = ms.Tensor(ndarr) | ||||
| assert isinstance(t1, ms.Tensor) | assert isinstance(t1, ms.Tensor) | ||||
| assert t1.dtype() == ms.float64 | |||||
| assert t1.dtype == ms.float64 | |||||
| t2 = ms.Tensor(np.zeros([1, 2, 3]), ms.float32) | t2 = ms.Tensor(np.zeros([1, 2, 3]), ms.float32) | ||||
| assert isinstance(t2, ms.Tensor) | assert isinstance(t2, ms.Tensor) | ||||
| assert t2.shape() == (1, 2, 3) | |||||
| assert t2.dtype() == ms.float32 | |||||
| assert t2.shape == (1, 2, 3) | |||||
| assert t2.dtype == ms.float32 | |||||
| t3 = ms.Tensor(0.1) | t3 = ms.Tensor(0.1) | ||||
| assert isinstance(t3, ms.Tensor) | assert isinstance(t3, ms.Tensor) | ||||
| assert t3.dtype() == ms.float64 | |||||
| assert t3.dtype == ms.float64 | |||||
| t4 = ms.Tensor(1) | t4 = ms.Tensor(1) | ||||
| assert isinstance(t4, ms.Tensor) | assert isinstance(t4, ms.Tensor) | ||||
| assert t4.dtype() == ms.int64 | |||||
| assert t4.dtype == ms.int64 | |||||
| def test_tensor_type_float16(): | def test_tensor_type_float16(): | ||||
| t_float16 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float16)) | t_float16 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float16)) | ||||
| assert isinstance(t_float16, ms.Tensor) | assert isinstance(t_float16, ms.Tensor) | ||||
| assert t_float16.shape() == (2, 3) | |||||
| assert t_float16.dtype() == ms.float16 | |||||
| assert t_float16.shape == (2, 3) | |||||
| assert t_float16.dtype == ms.float16 | |||||
| def test_tensor_type_float32(): | def test_tensor_type_float32(): | ||||
| t_float32 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)) | t_float32 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)) | ||||
| assert isinstance(t_float32, ms.Tensor) | assert isinstance(t_float32, ms.Tensor) | ||||
| assert t_float32.shape() == (2, 3) | |||||
| assert t_float32.dtype() == ms.float32 | |||||
| assert t_float32.shape == (2, 3) | |||||
| assert t_float32.dtype == ms.float32 | |||||
| def test_tensor_type_float32_user_define(): | def test_tensor_type_float32_user_define(): | ||||
| t = ms.Tensor(np.zeros([1, 2, 3]), ms.float32) | t = ms.Tensor(np.zeros([1, 2, 3]), ms.float32) | ||||
| assert isinstance(t, ms.Tensor) | assert isinstance(t, ms.Tensor) | ||||
| assert t.shape() == (1, 2, 3) | |||||
| assert t.dtype() == ms.float32 | |||||
| assert t.shape == (1, 2, 3) | |||||
| assert t.dtype == ms.float32 | |||||
| def test_tensor_type_float64(): | def test_tensor_type_float64(): | ||||
| t = ms.Tensor([[1.0, 2, 3], [4, 5, 6]]) | t = ms.Tensor([[1.0, 2, 3], [4, 5, 6]]) | ||||
| assert isinstance(t, ms.Tensor) | assert isinstance(t, ms.Tensor) | ||||
| assert t.shape() == (2, 3) | |||||
| assert t.dtype() == ms.float64 | |||||
| assert t.shape == (2, 3) | |||||
| assert t.dtype == ms.float64 | |||||
| t_zero = ms.Tensor(np.zeros([1, 2, 3])) | t_zero = ms.Tensor(np.zeros([1, 2, 3])) | ||||
| assert isinstance(t_zero, ms.Tensor) | assert isinstance(t_zero, ms.Tensor) | ||||
| assert t_zero.shape() == (1, 2, 3) | |||||
| assert t_zero.dtype() == ms.float64 | |||||
| assert t_zero.shape == (1, 2, 3) | |||||
| assert t_zero.dtype == ms.float64 | |||||
| def test_tensor_type_float64_user_define(): | def test_tensor_type_float64_user_define(): | ||||
| t = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=float)) | t = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=float)) | ||||
| assert isinstance(t, ms.Tensor) | assert isinstance(t, ms.Tensor) | ||||
| assert t.shape() == (2, 3) | |||||
| assert t.dtype() == ms.float64 | |||||
| assert t.shape == (2, 3) | |||||
| assert t.dtype == ms.float64 | |||||
| t_float64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]]), ms.float64) | t_float64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]]), ms.float64) | ||||
| assert isinstance(t_float64, ms.Tensor) | assert isinstance(t_float64, ms.Tensor) | ||||
| assert t_float64.shape() == (2, 3) | |||||
| assert t_float64.dtype() == ms.float64 | |||||
| assert t_float64.shape == (2, 3) | |||||
| assert t_float64.dtype == ms.float64 | |||||
| def test_tensor_type_bool(): | def test_tensor_type_bool(): | ||||
| # init a tensor with bool type | # init a tensor with bool type | ||||
| ts_bool_array = ms.Tensor(np.zeros([2, 3], np.bool), ms.bool_) | ts_bool_array = ms.Tensor(np.zeros([2, 3], np.bool), ms.bool_) | ||||
| assert isinstance(ts_bool_array, ms.Tensor) | assert isinstance(ts_bool_array, ms.Tensor) | ||||
| assert ts_bool_array.dtype() == ms.bool_ | |||||
| assert ts_bool_array.dtype == ms.bool_ | |||||
| t_bool = ms.Tensor(True) | t_bool = ms.Tensor(True) | ||||
| assert isinstance(t_bool, ms.Tensor) | assert isinstance(t_bool, ms.Tensor) | ||||
| assert t_bool.dtype() == ms.bool_ | |||||
| assert t_bool.dtype == ms.bool_ | |||||
| t_bool_array = ms.Tensor(np.array([[True, False, True], [False, False, False]])) | t_bool_array = ms.Tensor(np.array([[True, False, True], [False, False, False]])) | ||||
| assert isinstance(t_bool_array, ms.Tensor) | assert isinstance(t_bool_array, ms.Tensor) | ||||
| assert t_bool_array.shape() == (2, 3) | |||||
| assert t_bool_array.dtype() == ms.bool_ | |||||
| assert t_bool_array.shape == (2, 3) | |||||
| assert t_bool_array.dtype == ms.bool_ | |||||
| def test_tensor_type_int8(): | def test_tensor_type_int8(): | ||||
| t_int8_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8)) | t_int8_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8)) | ||||
| assert isinstance(t_int8_array, ms.Tensor) | assert isinstance(t_int8_array, ms.Tensor) | ||||
| assert t_int8_array.shape() == (2, 3) | |||||
| assert t_int8_array.dtype() == ms.int8 | |||||
| assert t_int8_array.shape == (2, 3) | |||||
| assert t_int8_array.dtype == ms.int8 | |||||
| def test_tensor_type_int16(): | def test_tensor_type_int16(): | ||||
| t_int16_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)) | t_int16_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)) | ||||
| assert isinstance(t_int16_array, ms.Tensor) | assert isinstance(t_int16_array, ms.Tensor) | ||||
| assert t_int16_array.shape() == (2, 3) | |||||
| assert t_int16_array.dtype() == ms.int16 | |||||
| assert t_int16_array.shape == (2, 3) | |||||
| assert t_int16_array.dtype == ms.int16 | |||||
| def test_tensor_type_int32(): | def test_tensor_type_int32(): | ||||
| t_int = ms.Tensor([[1, 2, 3], [4, 5, 6]]) | t_int = ms.Tensor([[1, 2, 3], [4, 5, 6]]) | ||||
| assert isinstance(t_int, ms.Tensor) | assert isinstance(t_int, ms.Tensor) | ||||
| assert t_int.shape() == (2, 3) | |||||
| assert t_int.dtype() == ms.int64 | |||||
| assert t_int.shape == (2, 3) | |||||
| assert t_int.dtype == ms.int64 | |||||
| t_int_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | t_int_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | ||||
| assert isinstance(t_int_array, ms.Tensor) | assert isinstance(t_int_array, ms.Tensor) | ||||
| assert t_int_array.shape() == (2, 3) | |||||
| assert t_int_array.dtype() == ms.int32 | |||||
| assert t_int_array.shape == (2, 3) | |||||
| assert t_int_array.dtype == ms.int32 | |||||
| def test_tensor_type_int64(): | def test_tensor_type_int64(): | ||||
| t_int64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)) | t_int64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)) | ||||
| assert isinstance(t_int64, ms.Tensor) | assert isinstance(t_int64, ms.Tensor) | ||||
| assert t_int64.shape() == (2, 3) | |||||
| assert t_int64.dtype() == ms.int64 | |||||
| assert t_int64.shape == (2, 3) | |||||
| assert t_int64.dtype == ms.int64 | |||||
| def test_tensor_type_uint8(): | def test_tensor_type_uint8(): | ||||
| t_uint8_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8)) | t_uint8_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8)) | ||||
| assert isinstance(t_uint8_array, ms.Tensor) | assert isinstance(t_uint8_array, ms.Tensor) | ||||
| assert t_uint8_array.shape() == (2, 3) | |||||
| assert t_uint8_array.dtype() == ms.uint8 | |||||
| assert t_uint8_array.shape == (2, 3) | |||||
| assert t_uint8_array.dtype == ms.uint8 | |||||
| def test_tensor_type_uint16(): | def test_tensor_type_uint16(): | ||||
| t_uint16_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)) | t_uint16_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)) | ||||
| assert isinstance(t_uint16_array, ms.Tensor) | assert isinstance(t_uint16_array, ms.Tensor) | ||||
| assert t_uint16_array.shape() == (2, 3) | |||||
| assert t_uint16_array.dtype() == ms.uint16 | |||||
| assert t_uint16_array.shape == (2, 3) | |||||
| assert t_uint16_array.dtype == ms.uint16 | |||||
| def test_tensor_type_uint32(): | def test_tensor_type_uint32(): | ||||
| t_uint32_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)) | t_uint32_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)) | ||||
| assert isinstance(t_uint32_array, ms.Tensor) | assert isinstance(t_uint32_array, ms.Tensor) | ||||
| assert t_uint32_array.shape() == (2, 3) | |||||
| assert t_uint32_array.dtype() == ms.uint32 | |||||
| assert t_uint32_array.shape == (2, 3) | |||||
| assert t_uint32_array.dtype == ms.uint32 | |||||
| def test_tensor_type_uint64(): | def test_tensor_type_uint64(): | ||||
| t_uint64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)) | t_uint64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)) | ||||
| assert isinstance(t_uint64, ms.Tensor) | assert isinstance(t_uint64, ms.Tensor) | ||||
| assert t_uint64.shape() == (2, 3) | |||||
| assert t_uint64.dtype() == ms.uint64 | |||||
| assert t_uint64.shape == (2, 3) | |||||
| assert t_uint64.dtype == ms.uint64 | |||||
| def test_set_type(): | def test_set_type(): | ||||
| t = ms.Tensor(ndarr) | t = ms.Tensor(ndarr) | ||||
| t.set_dtype(ms.float32) | t.set_dtype(ms.float32) | ||||
| assert t.dtype() == ms.float32 | |||||
| assert t.dtype == ms.float32 | |||||
| @non_graph_engine | @non_graph_engine | ||||
| @@ -250,11 +250,11 @@ def test_return_tensor(): | |||||
| tensor_ = exe(net, input_data) | tensor_ = exe(net, input_data) | ||||
| # get shape | # get shape | ||||
| shape_ = tensor_.shape() | |||||
| shape_ = tensor_.shape | |||||
| print("shape = ", shape_) | print("shape = ", shape_) | ||||
| # get type | # get type | ||||
| type_ = tensor_.dtype() | |||||
| type_ = tensor_.dtype | |||||
| print("type = ", type_) | print("type = ", type_) | ||||
| # get value | # get value | ||||
| @@ -71,7 +71,7 @@ def test_tensor_size(): | |||||
| def test_dtype(): | def test_dtype(): | ||||
| a = ms.Tensor(np.ones((2, 3), dtype=np.int32)) | a = ms.Tensor(np.ones((2, 3), dtype=np.int32)) | ||||
| assert a.dtype() == ms.int32 | |||||
| assert a.dtype == ms.int32 | |||||
| def test_asnumpy(): | def test_asnumpy(): | ||||
| @@ -89,7 +89,7 @@ def test_print(): | |||||
| def test_float(): | def test_float(): | ||||
| a = ms.Tensor(np.ones((2, 3)), ms.float16) | a = ms.Tensor(np.ones((2, 3)), ms.float16) | ||||
| assert a.dtype() == ms.float16 | |||||
| assert a.dtype == ms.float16 | |||||
| def test_tensor_method_sub(): | def test_tensor_method_sub(): | ||||
| @@ -71,7 +71,7 @@ def test(name, file_path, batch_size): | |||||
| data_list.append(data.asnumpy()) | data_list.append(data.asnumpy()) | ||||
| batch_data = np.concatenate(data_list, axis=0).transpose((0, 3, 1, 2)) | batch_data = np.concatenate(data_list, axis=0).transpose((0, 3, 1, 2)) | ||||
| input_tensor = Tensor(batch_data) | input_tensor = Tensor(batch_data) | ||||
| print(input_tensor.shape()) | |||||
| print(input_tensor.shape) | |||||
| network(input_tensor) | network(input_tensor) | ||||
| @@ -23,7 +23,7 @@ from mindspore import dtype as mstype | |||||
| def test_check_layer_norm_1(): | def test_check_layer_norm_1(): | ||||
| x = Tensor(np.ones([20, 5, 10, 10]), mstype.float32) | x = Tensor(np.ones([20, 5, 10, 10]), mstype.float32) | ||||
| shape1 = x.shape()[1:] | |||||
| shape1 = x.shape[1:] | |||||
| m = nn.LayerNorm(shape1, -1, 1) | m = nn.LayerNorm(shape1, -1, 1) | ||||
| with pytest.raises(NotImplementedError): | with pytest.raises(NotImplementedError): | ||||
| m(x) | m(x) | ||||
| @@ -31,7 +31,7 @@ def test_check_layer_norm_1(): | |||||
| def test_check_layer_norm_2(): | def test_check_layer_norm_2(): | ||||
| x = Tensor(np.ones([20, 5, 10, 10]), mstype.float32) | x = Tensor(np.ones([20, 5, 10, 10]), mstype.float32) | ||||
| shape1 = x.shape()[1:] | |||||
| shape1 = x.shape[1:] | |||||
| m = nn.LayerNorm(shape1, begin_params_axis=1) | m = nn.LayerNorm(shape1, begin_params_axis=1) | ||||
| with pytest.raises(NotImplementedError): | with pytest.raises(NotImplementedError): | ||||
| m(x) | m(x) | ||||
| @@ -65,7 +65,7 @@ def test_init_Initializer(): | |||||
| def test_init_tensor(): | def test_init_tensor(): | ||||
| tensor = ms.Tensor(np.zeros([1, 2, 3])) | tensor = ms.Tensor(np.zeros([1, 2, 3])) | ||||
| tensor = init.initializer(tensor, [1, 2, 3], ms.float32) | tensor = init.initializer(tensor, [1, 2, 3], ms.float32) | ||||
| assert tensor.shape() == (1, 2, 3) | |||||
| assert tensor.shape == (1, 2, 3) | |||||
| def test_init_zero_default_dtype(): | def test_init_zero_default_dtype(): | ||||
| @@ -126,8 +126,8 @@ def test_load_checkpoint(): | |||||
| assert len(par_dict) == 3 | assert len(par_dict) == 3 | ||||
| assert par_dict['param_test'].name == 'param_test' | assert par_dict['param_test'].name == 'param_test' | ||||
| assert par_dict['param_test'].data.dtype() == mstype.float32 | |||||
| assert par_dict['param_test'].data.shape() == (1, 3, 224, 224) | |||||
| assert par_dict['param_test'].data.dtype == mstype.float32 | |||||
| assert par_dict['param_test'].data.shape == (1, 3, 224, 224) | |||||
| assert isinstance(par_dict, dict) | assert isinstance(par_dict, dict) | ||||
| @@ -46,7 +46,7 @@ def vm_impl_dType(self): | |||||
| def vm_impl(x): | def vm_impl(x): | ||||
| # update the src type | # update the src type | ||||
| return x.dtype() | |||||
| return x.dtype | |||||
| return vm_impl | return vm_impl | ||||