| @@ -19,12 +19,13 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore.ops import functional as F | |||
| from ..cell import Cell | |||
| from ...common import dtype as mstype | |||
| from ..._checkparam import Validator as validator | |||
| __all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul'] | |||
| __all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul', 'Moments'] | |||
| class ReduceLogSumExp(Cell): | |||
| @@ -451,3 +452,65 @@ class MatMul(Cell): | |||
| matmul_broadcast = self.squeeze_right_op(matmul_broadcast) | |||
| return matmul_broadcast | |||
| @constexpr | |||
| def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name): | |||
| validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name) | |||
| class Moments(Cell): | |||
| """ | |||
| Calculate the mean and variance of `x`. | |||
| Args: | |||
| axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: (). | |||
| keep_dims (bool): If true, The dimension of mean and variance are identical with input's. | |||
| If false, don't keep these dimensions. Default: False. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported. | |||
| Outputs: | |||
| - **mean** (Tensor) - The mean of input x, with the same date type as input x. | |||
| - **variance** (Tensor) - The variance of input x, with the same date type as input x. | |||
| Examples: | |||
| >>> net = nn.Moments(axis=3, keep_dims=True) | |||
| >>> input_x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32) | |||
| >>> mean, var = net(input_x) | |||
| mean: [[[[2.5], [4.5]]]] | |||
| var: [[[[1.25], [1.25]]]] | |||
| """ | |||
| def __init__(self, axis=None, keep_dims=None): | |||
| super(Moments, self).__init__() | |||
| if axis is None: | |||
| axis = () | |||
| if isinstance(axis, tuple): | |||
| for idx, item in enumerate(axis): | |||
| validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name) | |||
| self.axis = validator.check_value_type('axis', axis, [int, tuple], self.cls_name) | |||
| if keep_dims is None: | |||
| keep_dims = False | |||
| self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name) | |||
| self.cast = P.Cast() | |||
| self.reduce_mean = P.ReduceMean(keep_dims=True) | |||
| self.square_diff = P.SquaredDifference() | |||
| self.squeeze = P.Squeeze(self.axis) | |||
| def construct(self, x): | |||
| tensor_dtype = x.dtype | |||
| _check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name) | |||
| if tensor_dtype == mstype.float16: | |||
| x = self.cast(x, mstype.float32) | |||
| mean = self.reduce_mean(x, self.axis) | |||
| variance = self.reduce_mean(self.square_diff(x, F.stop_gradient(mean)), self.axis) | |||
| if not self.keep_dims: | |||
| mean = self.squeeze(mean) | |||
| variance = self.squeeze(variance) | |||
| if tensor_dtype == mstype.float16: | |||
| mean = self.cast(mean, mstype.float16) | |||
| variance = self.cast(variance, mstype.float16) | |||
| return mean, variance | |||
| return mean, variance | |||
| @@ -27,6 +27,7 @@ from .multitype_ops.add_impl import hyper_add | |||
| from .multitype_ops.ones_like_impl import ones_like | |||
| from .multitype_ops.zeros_like_impl import zeros_like | |||
| from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial | |||
| from .math_ops import count_nonzero | |||
| __all__ = [ | |||
| @@ -47,4 +48,5 @@ __all__ = [ | |||
| 'gamma', | |||
| 'poisson', | |||
| 'multinomial', | |||
| 'clip_by_value',] | |||
| 'clip_by_value', | |||
| 'count_nonzero'] | |||
| @@ -0,0 +1,75 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """math Operations.""" | |||
| from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore.ops import functional as F | |||
| from .. import operations as P | |||
| @constexpr | |||
| def _check_validate_axis(axis, name): | |||
| if isinstance(axis, (tuple, list)): | |||
| for idx, item in enumerate(axis): | |||
| validator.check_value_type("axis[%d]" % idx, item, [int], name) | |||
| axis = validator.check_value_type('axis', axis, [int, tuple, list], name) | |||
| return axis | |||
| @constexpr | |||
| def _check_validate_keepdims(keep_dims, name): | |||
| keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name) | |||
| return keep_dims | |||
| def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32): | |||
| """ | |||
| Count number of nonzero elements across axis of input tensor | |||
| Args: | |||
| - **x** (Tensor[Number]) - Input data is used to count non-zero numbers. | |||
| - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Only constant value is allowed. | |||
| Default: (), reduce all dimensions. | |||
| - **keep_dims** (bool) - If true, keep these reduced dimensions and the length is 1. | |||
| If false, don't keep these dimensions. Default: False. | |||
| - **dtype** (Union[Number, mstype.bool_]) - The data type of the output tensor. Only constant value is allowed. | |||
| Default: mstype.int32 | |||
| Returns: | |||
| Tensor, number of nonzero element. The data type is dtype. | |||
| Examples: | |||
| >>> input_tensor = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32)) | |||
| >>> nonzero_num = count_nonzero(x=input_x, axis=[0, 1], keep_dims=True, dtype=mstype.int32) | |||
| nonzero_num: [[3]] | |||
| """ | |||
| const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x') | |||
| axis = _check_validate_axis(axis, "count_nonzero") | |||
| keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero") | |||
| const_utils.check_valid_type(dtype, mstype.number_type + (mstype.bool_,), 'dtype') | |||
| not_equal = P.NotEqual() | |||
| cast = P.Cast() | |||
| reduce_sum = P.ReduceSum(keep_dims) | |||
| nonzero_bool = not_equal(x, 0) | |||
| # ReduceSum only support float16 or float32 tensor. | |||
| nonzero_val = cast(nonzero_bool, mstype.float16) | |||
| nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype) | |||
| return nonzero_num | |||
| @@ -241,9 +241,9 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer): | |||
| - **max** (Tensor) - Value of the max range of the input data x. | |||
| Outputs: | |||
| - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape date type as input x. | |||
| - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape date type as input min. | |||
| - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape date type as input max. | |||
| - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and date type as input x. | |||
| - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and date type as input min. | |||
| - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and date type as input max. | |||
| Examples: | |||
| >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | |||
| @@ -356,9 +356,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer): | |||
| - **max** (Tensor) - Value of the max range of the input data x. | |||
| Outputs: | |||
| - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape date type as input x. | |||
| - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape date type as input min. | |||
| - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape date type as input max. | |||
| - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and date type as input x. | |||
| - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and date type as input min. | |||
| - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and date type as input max. | |||
| Examples: | |||
| >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32) | |||
| @@ -489,6 +489,7 @@ class DynamicShape(Primitive): | |||
| self.add_prim_attr('is_dynamic_shape', True) | |||
| self.add_prim_attr("dynamic_shape_depends", [0]) | |||
| class Squeeze(PrimitiveWithInfer): | |||
| """ | |||
| Returns a tensor with the same type but dimensions of 1 are removed based on `axis`. | |||
| @@ -26,7 +26,7 @@ from mindspore.ops import functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.ops.operations._quant_ops import FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsPerChannel | |||
| from mindspore.ops.operations import _quant_ops as Q | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| @@ -216,6 +216,31 @@ class HistogramSummaryNet(nn.Cell): | |||
| return out | |||
| class Moments(nn.Cell): | |||
| """Moments net definition""" | |||
| def __init__(self, axis=None, keep_dims=None): | |||
| super(Moments, self).__init__() | |||
| self.moments = nn.Moments(axis=axis, keep_dims=keep_dims) | |||
| def construct(self, input_x): | |||
| mean, variance = self.moments(input_x) | |||
| return mean, variance | |||
| class CountNonZero(nn.Cell): | |||
| """CountNonZero net definition""" | |||
| def __init__(self, axis, keep_dims, dtype): | |||
| super(CountNonZero, self).__init__() | |||
| self.axis = axis | |||
| self.keep_dims = keep_dims | |||
| self.dtype = dtype | |||
| def construct(self, input_x): | |||
| nonzero_num = C.count_nonzero(input_x, self.axis, self.keep_dims, self.dtype) | |||
| return nonzero_num | |||
| class ScatterUpdate(nn.Cell): | |||
| """ScatterUpdate net definition""" | |||
| @@ -1057,14 +1082,22 @@ test_case_math_ops = [ | |||
| 'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])), | |||
| [2, 3], [2, 3]], | |||
| 'desc_bprop': [[2, 3]]}), | |||
| ('Moments', { | |||
| 'block': Moments(axis=(), keep_dims=False), | |||
| 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ('CountNonZero', { | |||
| 'block': CountNonZero(axis=(), keep_dims=False, dtype=mstype.int32), | |||
| 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ('FakeQuantWithMinMaxVars', { | |||
| 'block': FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False), | |||
| 'block': Q.FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False), | |||
| 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32), | |||
| Tensor(np.array([-6]), mstype.float32), | |||
| Tensor(np.array([6]), mstype.float32)], | |||
| 'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)]}), | |||
| ('FakeQuantWithMinMaxVarsPerChannel', { | |||
| 'block': FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False), | |||
| 'block': Q.FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False), | |||
| 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32), | |||
| Tensor(np.array([-6, -1, -2, -3]), mstype.float32), | |||
| Tensor(np.array([6, 1, 2, 3]), mstype.float32)], | |||