@@ -62,6 +62,10 @@ class FakeQuantWithMinMax(Cell):
     Outputs:
         Tensor, with the same type and shape as the `x`.
+    Examples:
+        >>> fake_quant = nn.FakeQuantWithMinMax()
+        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = fake_quant(input_x)
     """
     def __init__(self,
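For reference, the simulated quantize/dequantize round trip this cell performs can be sketched in plain NumPy. The sketch below assumes asymmetric quantization over the observed [min, max] with narrow_range off; it is a reference approximation, not the MindSpore kernel:

>>> import numpy as np
>>> x = np.array([[1, 2, 1], [-2, 0, -1]], np.float32)
>>> x_min, x_max, num_bits = x.min(), x.max(), 8
>>> scale = (x_max - x_min) / (2 ** num_bits - 1)        # quantization step size
>>> q = np.round((np.clip(x, x_min, x_max) - x_min) / scale)
>>> fake_quantized = q * scale + x_min                   # back to float32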
@@ -182,6 +186,12 @@ class Conv2dBatchNormQuant(Cell):
     Outputs:
         Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+    Examples:
+        >>> batchnorm_quant = nn.Conv2dBatchNormQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
+        ...                                           dilation=(1, 1))
+        >>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
+        >>> result = batchnorm_quant(input_x)
     """
     def __init__(self,
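The "BatchNorm" part of this cell is folded into the convolution weight before the weight is fake-quantized. The standard folding identity, sketched in NumPy with hypothetical shapes (this illustrates the idea, not the class internals):

>>> import numpy as np
>>> weight = np.random.randn(6, 1, 2, 2).astype(np.float32)   # (out, in, kH, kW)
>>> gamma = np.ones(6, np.float32)                            # BN scale
>>> sigma = np.full(6, 2.0, np.float32)                       # running std, sqrt(var + eps)
>>> weight_fold = weight * (gamma / sigma)[:, None, None, None]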
@@ -339,6 +349,11 @@ class Conv2dQuant(_Conv):
     Outputs:
         Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+    Examples:
+        >>> conv2d_quant = nn.Conv2dQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
+        ...                               dilation=(1, 1))
+        >>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
+        >>> result = conv2d_quant(input_x)
     """
     def __init__(self,
@@ -412,6 +427,11 @@ class DenseQuant(Cell):
     Outputs:
         Tensor of shape :math:`(N, C_{out})`.
+    Examples:
+        >>> dense_quant = nn.DenseQuant(3, 6)
+        >>> input_x = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
+        >>> result = dense_quant(input_x)
     """
     def __init__(
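Underneath the weight fake-quantization, DenseQuant computes the usual affine map out = x @ W.T + b, so the output shape is (batch, out_channels), not the 4-D conv shape. A minimal NumPy check of the shapes used above:

>>> import numpy as np
>>> x = np.random.randn(2, 3).astype(np.float32)
>>> w = np.random.randn(6, 3).astype(np.float32)   # (out_channels, in_channels)
>>> b = np.zeros(6, np.float32)
>>> (x @ w.T + b).shape
(2, 6)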
@@ -503,6 +523,10 @@ class ReLUQuant(Cell):
     Outputs:
         Tensor, with the same type and shape as the `x`.
+    Examples:
+        >>> relu_quant = nn.ReLUQuant()
+        >>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32)
+        >>> result = relu_quant(input_x)
     """
     def __init__(self,
@@ -546,6 +570,10 @@ class ReLU6Quant(Cell):
     Outputs:
         Tensor, with the same type and shape as the `x`.
+    Examples:
+        >>> relu6_quant = nn.ReLU6Quant(4, 1)
+        >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = relu6_quant(input_x)
     """
     def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
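Numerically, the activation these cells wrap is relu6(x) = min(max(x, 0), 6); the quant cell adds a fake-quant step on that output. As a NumPy reference for the input above:

>>> import numpy as np
>>> x = np.array([[1, 2, -1], [-2, 0, -1]], np.float32)
>>> relu6 = np.minimum(np.maximum(x, 0), 6)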
@@ -584,6 +612,10 @@ class HSwishQuant(Cell):
     Outputs:
         Tensor, with the same type and shape as the `x`.
+    Examples:
+        >>> hswish_quant = nn.HSwishQuant(4, 1)
+        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = hswish_quant(input_x)
     """
     def __init__(self,
@@ -633,6 +665,10 @@ class HSigmoidQuant(Cell):
     Outputs:
         Tensor, with the same type and shape as the `x`.
+    Examples:
+        >>> hsigmoid_quant = nn.HSigmoidQuant(4, 1)
+        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = hsigmoid_quant(input_x)
     """
     def __init__(self,
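Both hard activations are piecewise-linear: hsigmoid(x) = relu6(x + 3) / 6 and hswish(x) = x * relu6(x + 3) / 6, which makes the expected values of these examples easy to compute by hand:

>>> import numpy as np
>>> x = np.array([[1, 2, 1], [-2, 0, -1]], np.float32)
>>> r6 = np.minimum(np.maximum(x + 3, 0), 6)
>>> hsigmoid, hswish = r6 / 6, x * r6 / 6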
@@ -682,6 +718,11 @@ class TensorAddQuant(Cell):
     Outputs:
         Tensor, with the same type and shape as the `x`.
+    Examples:
+        >>> add_quant = nn.TensorAddQuant()
+        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
+        >>> result = add_quant(input_x, input_y)
     """
     def __init__(self,
@@ -98,7 +98,17 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
 class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
-    """Performs grad of FakeQuantWithMinMax operation."""
+    r"""
+    Performs grad of FakeQuantWithMinMax operation.
+    Examples:
+        >>> fake_min_max_grad = P.FakeQuantWithMinMaxGrad()
+        >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
+        >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
+        >>> _min = Tensor(np.array([-4]), mindspore.float32)
+        >>> _max = Tensor(np.array([2]), mindspore.float32)
+        >>> result = fake_min_max_grad(dout, input_x, _min, _max)
+    """
     support_quant_bit = [4, 8]
     @prim_attr_register
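The backward pass of fake quantization is conventionally a straight-through estimator: the incoming gradient is kept where the input lies inside [min, max] and zeroed elsewhere. A NumPy sketch of that rule, assumed here (the fused kernel may differ in details):

>>> import numpy as np
>>> dout = np.array([[-2.3, 1.2], [5.7, 0.2]], np.float32)
>>> x = np.array([[18, -23], [0.2, 6]], np.float32)
>>> dx = dout * ((x >= -4) & (x <= 2))   # only x = 0.2 lies inside [-4, 2]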
@@ -149,10 +159,11 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
         - Tensor, has the same type as input.
     Examples:
-        >>> input_tensor = Tensor(np.random.rand(3,4,5,5), mstype.float32)
-        >>> min_tensor = Tensor(np.array([-6.0, -6.5, -4.0, -5.0]), mstype.float32)
-        >>> max_tensor = Tensor(np.array([6.0, 6.5, 4.0, 5.0]), mstype.float32)
-        >>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
+        >>> fake_quant = P.FakeQuantWithMinMaxPerChannel()
+        >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
+        >>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
+        >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
+        >>> result = fake_quant(input_x, _min, _max)
     """
     support_quant_bit = [4, 8]
     channel_idx = 0
@@ -190,7 +201,17 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
 class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
-    """Performs grad of FakeQuantWithMinMaxPerChannel operation."""
+    r"""
+    Performs grad of FakeQuantWithMinMaxPerChannel operation.
+    Examples:
+        >>> fqmmpc_grad = P.FakeQuantWithMinMaxPerChannelGrad()
+        >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
+        >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
+        >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
+        >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
+        >>> result = fqmmpc_grad(dout, input_x, _min, _max)
+    """
     support_quant_bit = [4, 8]
     @prim_attr_register
@@ -243,6 +264,13 @@ class BatchNormFold(PrimitiveWithInfer):
         - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
         - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
+    Examples:
+        >>> batch_norm_fold = P.BatchNormFold()
+        >>> input_x = Tensor(np.array([1, 2, -1, -2, -2, 1]).reshape(2, 3), mindspore.float32)
+        >>> mean = Tensor(np.array([0.5, -1, 1]), mindspore.float32)
+        >>> variance = Tensor(np.array([0.36, 0.4, 0.49]), mindspore.float32)
+        >>> global_step = Tensor(np.arange(6), mindspore.int32)
+        >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
     """
     channel = 1
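The batch statistics this primitive returns are the per-channel mean and standard deviation of the current batch; the running statistics are their moving averages. In NumPy terms (epsilon value assumed for illustration):

>>> import numpy as np
>>> x = np.array([1, 2, -1, -2, -2, 1], np.float32).reshape(2, 3)
>>> batch_mean = x.mean(axis=0)
>>> batch_std = np.sqrt(x.var(axis=0) + 1e-12)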
@@ -273,7 +301,19 @@ class BatchNormFold(PrimitiveWithInfer):
 class BatchNormFoldGrad(PrimitiveWithInfer):
-    """Performs grad of BatchNormFold operation."""
+    r"""
+    Performs grad of BatchNormFold operation.
+    Examples:
+        >>> batch_norm_fold_grad = P.BatchNormFoldGrad()
+        >>> d_batch_mean = Tensor(np.random.randint(-2, 2, (1, 2, 2, 3)), mindspore.float32)
+        >>> d_batch_std = Tensor(np.random.randn(1, 2, 2, 3), mindspore.float32)
+        >>> input_x = Tensor(np.random.randint(0, 256, (4, 1, 4, 6)), mindspore.float32)
+        >>> batch_mean = Tensor(np.random.randint(-8, 8, (1, 2, 2, 3)), mindspore.float32)
+        >>> batch_std = Tensor(np.random.randint(0, 12, (1, 2, 2, 3)), mindspore.float32)
+        >>> global_step = Tensor([2], mindspore.int32)
+        >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
+    """
     channel = 1
     @prim_attr_register
@@ -321,6 +361,12 @@ class CorrectionMul(PrimitiveWithInfer):
     Outputs:
         - **out** (Tensor) - Tensor has the same shape as x.
+    Examples:
+        >>> correction_mul = P.CorrectionMul()
+        >>> input_x = Tensor(np.random.randint(-8, 12, (3, 4)), mindspore.float32)
+        >>> batch_std = Tensor(np.array([1.5, 3, 2]), mindspore.float32)
+        >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
+        >>> out = correction_mul(input_x, batch_std, running_std)
     """
     channel = 0
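CorrectionMul rescales the weight by batch_std / running_std along the channel axis (channel = 0 for this primitive), which reduces to a broadcast multiply; a NumPy sketch:

>>> import numpy as np
>>> x = np.random.randn(3, 4).astype(np.float32)
>>> batch_std = np.array([1.5, 3, 2], np.float32)
>>> running_std = np.array([2, 1.2, 0.5], np.float32)
>>> out = x * (batch_std / running_std)[:, None]   # broadcast over axis 0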
@@ -343,7 +389,17 @@ class CorrectionMul(PrimitiveWithInfer):
 class CorrectionMulGrad(PrimitiveWithInfer):
-    """Performs grad of CorrectionMul operation."""
+    r"""
+    Performs grad of CorrectionMul operation.
+    Examples:
+        >>> correction_mul_grad = P.CorrectionMulGrad()
+        >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
+        >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
+        >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
+        >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
+        >>> result = correction_mul_grad(dout, input_x, gamma, running_std)
+    """
     channel = 0
     @prim_attr_register
@@ -385,6 +441,18 @@ class BatchNormFold2(PrimitiveWithInfer):
     Outputs:
         - **y** (Tensor) - Tensor has the same shape as x.
+    Examples:
+        >>> batch_norm_fold2 = P.BatchNormFold2()
+        >>> input_x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
+        >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
+        >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
+        >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
+        >>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
+        >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
+        >>> running_mean = Tensor(np.array([-0.1, 0, -0.1]), mindspore.float32)
+        >>> global_step = Tensor(np.random.randint(1, 8, (8,)), mindspore.int32)
+        >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
+        ...                           running_std, running_mean, global_step)
     """
     channel = 1
@@ -418,7 +486,21 @@ class BatchNormFold2(PrimitiveWithInfer):
 class BatchNormFold2Grad(PrimitiveWithInfer):
-    """Performs grad of CorrectionAddGrad operation."""
+    r"""
+    Performs grad of BatchNormFold2 operation.
+    Examples:
+        >>> bnf2_grad = P.BatchNormFold2Grad()
+        >>> input_x = Tensor(np.arange(3*3*12*12).reshape(6, 3, 6, 12), mindspore.float32)
+        >>> dout = Tensor(np.random.randint(-32, 32, (6, 3, 6, 12)), mindspore.float32)
+        >>> gamma = Tensor(np.random.randint(-4, 4, (3, 1, 1, 2)), mindspore.float32)
+        >>> batch_std = Tensor(np.random.randint(0, 8, (3, 1, 1, 2)), mindspore.float32)
+        >>> batch_mean = Tensor(np.random.randint(-6, 6, (3, 1, 1, 2)), mindspore.float32)
+        >>> running_std = Tensor(np.linspace(0, 2, 6).reshape(3, 1, 1, 2), mindspore.float32)
+        >>> running_mean = Tensor(np.random.randint(-3, 3, (3, 1, 1, 2)), mindspore.float32)
+        >>> global_step = Tensor(np.array([-2]), mindspore.int32)
+        >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
+    """
     channel = 1
     @prim_attr_register
@@ -1156,6 +1156,16 @@ class Tile(PrimitiveWithInfer):
     For example, if the shape of `input_x` is :math:`(1, ..., x_1, x_2, ..., x_S)`,
     then the lengths at corresponding positions are multiplied elementwise, and
     the output shape is :math:`(1*y_1, ..., x_S*y_R)`.
+    Examples:
+        >>> tile = P.Tile()
+        >>> input_x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
+        >>> multiples = (2, 3)
+        >>> result = tile(input_x, multiples)
+        [[1. 2. 1. 2. 1. 2.]
+         [3. 4. 3. 4. 3. 4.]
+         [1. 2. 1. 2. 1. 2.]
+         [3. 4. 3. 4. 3. 4.]]
     """
     @prim_attr_register
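The result matches NumPy's np.tile, which is a convenient way to sanity-check a choice of multiples:

>>> import numpy as np
>>> np.tile(np.array([[1, 2], [3, 4]], np.float32), (2, 3))   # same 4 x 6 matrix as above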
@@ -144,6 +144,12 @@ class Merge(PrimitiveWithInfer):
     Outputs:
         tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape as the elements of `inputs`.
+    Examples:
+        >>> merge = P.Merge()
+        >>> input_x = Tensor(np.linspace(0, 8, 8).reshape(2, 4), mindspore.float32)
+        >>> input_y = Tensor(np.random.randint(-4, 4, (2, 4)), mindspore.float32)
+        >>> result = merge((input_x, input_y))
     """
     @prim_attr_register
@@ -713,6 +713,12 @@ class Neg(PrimitiveWithInfer):
     Outputs:
         Tensor, has the same shape and dtype as input.
+    Examples:
+        >>> neg = P.Neg()
+        >>> input_x = Tensor(np.array([1, 2, -1, 2, 0, -3.5]), mindspore.float32)
+        >>> result = neg(input_x)
+        [-1. -2. 1. -2. 0. 3.5]
     """
     @prim_attr_register
@@ -1623,6 +1629,7 @@ class LogicalOr(_LogicBinaryOp):
     def infer_dtype(self, x_dtype, y_dtype):
         return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
+
 class IsNan(PrimitiveWithInfer):
     """
     Judging which elements are nan for each position.
@@ -1632,6 +1639,11 @@ class IsNan(PrimitiveWithInfer):
     Outputs:
         Tensor, has the same shape as the input, and the dtype is bool.
+    Examples:
+        >>> is_nan = P.IsNan()
+        >>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
+        >>> result = is_nan(input_x)
     """
     @prim_attr_register
@@ -1645,6 +1657,7 @@ class IsNan(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype):
         return mstype.bool_
+
 class IsInf(PrimitiveWithInfer):
     """
     Judging which elements are inf or -inf for each position.
@@ -1654,6 +1667,11 @@ class IsInf(PrimitiveWithInfer):
     Outputs:
         Tensor, has the same shape as the input, and the dtype is bool.
+    Examples:
+        >>> is_inf = P.IsInf()
+        >>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
+        >>> result = is_inf(input_x)
     """
     @prim_attr_register
@@ -1667,6 +1685,7 @@ class IsInf(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype):
         return mstype.bool_
+
 class IsFinite(PrimitiveWithInfer):
     """
     Judging which elements are finite for each position.
@@ -1676,6 +1695,12 @@ class IsFinite(PrimitiveWithInfer):
     Outputs:
         Tensor, has the same shape as the input, and the dtype is bool.
+    Examples:
+        >>> is_finite = P.IsFinite()
+        >>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
+        >>> result = is_finite(input_x)
+        [False True False]
     """
     @prim_attr_register
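On the shared input above, np.log(-1) evaluates to nan (with a RuntimeWarning) and np.log(0) to -inf, so all three predicates line up with their NumPy counterparts:

>>> import numpy as np
>>> x = np.array([np.log(-1), 1, np.log(0)], np.float32)   # [nan, 1., -inf]
>>> np.isnan(x)      # [ True False False]
>>> np.isinf(x)      # [False False  True]
>>> np.isfinite(x)   # [False  True False]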
@@ -1691,6 +1716,7 @@ class IsFinite(PrimitiveWithInfer):
         validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
         return mstype.bool_
+
 class FloatStatus(PrimitiveWithInfer):
     """
     Determine if the elements contain nan, inf or -inf. `0` for normal, `1` for overflow.
@@ -1701,6 +1727,11 @@ class FloatStatus(PrimitiveWithInfer):
     Outputs:
         Tensor, has the shape of `(1,)`, and the same dtype as the input, either
         `mindspore.dtype.float32` or `mindspore.dtype.float16`.
+    Examples:
+        >>> float_status = P.FloatStatus()
+        >>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
+        >>> result = float_status(input_x)
     """
     @prim_attr_register
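FloatStatus reduces that elementwise check to a single overflow flag of shape (1,); a sketch of the documented contract:

>>> import numpy as np
>>> x = np.array([np.log(-1), 1, np.log(0)], np.float32)
>>> flag = np.array([0.0 if np.isfinite(x).all() else 1.0], np.float32)   # -> [1.]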
@@ -1714,6 +1745,7 @@ class FloatStatus(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype):
         return x_dtype
+
 class NPUAllocFloatStatus(PrimitiveWithInfer):
     """
     Allocates a flag to store the overflow status.
@@ -393,6 +393,10 @@ class HSwish(PrimitiveWithInfer):
     Outputs:
         Tensor, with the same type and shape as the `input_data`.
+    Examples:
+        >>> hswish = P.HSwish()
+        >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
+        >>> result = hswish(input_x)
     """
     @prim_attr_register
     def __init__(self):
@@ -406,7 +410,6 @@ class HSwish(PrimitiveWithInfer):
         return x_dtype
-
 class Sigmoid(PrimitiveWithInfer):
     r"""
     Sigmoid activation function.
@@ -462,6 +465,10 @@ class HSigmoid(PrimitiveWithInfer):
     Outputs:
         Tensor, with the same type and shape as the `input_data`.
+    Examples:
+        >>> hsigmoid = P.HSigmoid()
+        >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
+        >>> result = hsigmoid(input_x)
     """
     @prim_attr_register
@@ -883,7 +890,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
             h_out = math.ceil(x_shape[2] / stride_h)
             w_out = math.ceil(x_shape[3] / stride_w)
-            pad_needed_h = max(0, (h_out - 1) * stride_h+ dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
+            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
             pad_top = math.floor(pad_needed_h / 2)
             pad_bottom = pad_needed_h - pad_top
@@ -1138,6 +1145,33 @@ class AvgPool(_Pool):
     Outputs:
         Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.avgpool_op = P.AvgPool(padding="VALID", ksize=2, strides=1)
+        >>>
+        >>>     def construct(self, x):
+        >>>         result = self.avgpool_op(x)
+        >>>         return result
+        >>>
+        >>> input_x = Tensor(np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4), mindspore.float32)
+        >>> net = Net()
+        >>> result = net(input_x)
+        [[[[ 2.5  3.5  4.5]
+           [ 6.5  7.5  8.5]]
+          [[14.5 15.5 16.5]
+           [18.5 19.5 20.5]]
+          [[26.5 27.5 28.5]
+           [30.5 31.5 32.5]]]]
     """
     @prim_attr_register
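The printed values are easy to verify: with a 2x2 window and stride 1 under VALID padding, each output element is the mean of four neighbours, which NumPy reproduces by averaging four shifted views:

>>> import numpy as np
>>> x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
>>> out = (x[..., :-1, :-1] + x[..., :-1, 1:] + x[..., 1:, :-1] + x[..., 1:, 1:]) / 4
>>> out.shape   # (1, 3, 2, 3), values as printed above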
@@ -1590,6 +1624,16 @@ class SGD(PrimitiveWithInfer):
     Outputs:
         Tensor, parameters to be updated.
+    Examples:
+        >>> sgd = P.SGD()
+        >>> parameters = Tensor(np.array([2, -0.5, 1.7, 4]), mindspore.float32)
+        >>> gradient = Tensor(np.array([1, -1, 0.5, 2]), mindspore.float32)
+        >>> learning_rate = Tensor(0.01, mindspore.float32)
+        >>> accum = Tensor(np.array([0.1, 0.3, -0.2, -0.1]), mindspore.float32)
+        >>> momentum = Tensor(0.1, mindspore.float32)
+        >>> stat = Tensor(np.array([1.5, -0.3, 0.2, -0.7]), mindspore.float32)
+        >>> result = sgd(parameters, gradient, learning_rate, accum, momentum, stat)
     """
     @prim_attr_register
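The update behind this primitive is classic momentum SGD; ignoring the stat input (first-step bookkeeping) and assuming zero dampening, one step looks like:

>>> import numpy as np
>>> p = np.array([2, -0.5, 1.7, 4], np.float32)        # parameters
>>> g = np.array([1, -1, 0.5, 2], np.float32)          # gradient
>>> v = np.array([0.1, 0.3, -0.2, -0.1], np.float32)   # accum
>>> lr, momentum = 0.01, 0.1
>>> v = momentum * v + g
>>> p = p - lr * v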
@@ -1620,6 +1664,7 @@ class SGD(PrimitiveWithInfer):
         validator.check_tensor_type_same({"stat": stat_dtype}, valid_types, self.name)
         return parameters_dtype
+
 class ApplyRMSProp(PrimitiveWithInfer):
     """
     Optimizer that implements the Root Mean Square prop(RMSProp) algorithm.
@@ -1659,6 +1704,18 @@ class ApplyRMSProp(PrimitiveWithInfer):
     Outputs:
         Tensor, parameters to be updated.
+    Examples:
+        >>> apply_rms = P.ApplyRMSProp()
+        >>> input_x = Tensor(np.random.randint(0, 256, (3, 3)), mindspore.float32)
+        >>> mean_square = Tensor(np.random.randint(0, 256, (3, 3)), mindspore.float32)
+        >>> moment = Tensor(np.random.randn(3, 3), mindspore.float32)
+        >>> grad = Tensor(np.random.randint(-32, 16, (3, 3)), mindspore.float32)
+        >>> learning_rate = 0.9
+        >>> decay = 0.0
+        >>> momentum = 1e-10
+        >>> epsilon = 0.001
+        >>> result = apply_rms(input_x, mean_square, moment, grad, learning_rate, decay, momentum, epsilon)
     """
     @prim_attr_register
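Written out, the usual non-centered RMSProp step the docstring describes is the following; treat it as a reference sketch rather than the fused kernel:

>>> import numpy as np
>>> w, ms, mom = np.ones(4, np.float32), np.ones(4, np.float32), np.zeros(4, np.float32)
>>> g = np.full(4, 0.5, np.float32)
>>> lr, decay, momentum, eps = 0.9, 0.0, 1e-10, 0.001
>>> ms = decay * ms + (1 - decay) * g * g
>>> mom = momentum * mom + lr * g / np.sqrt(ms + eps)
>>> w = w - mom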
@@ -1729,6 +1786,20 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
     Outputs:
         Tensor, parameters to be updated.
+    Examples:
+        >>> centered_rms_prop = P.ApplyCenteredRMSProp()
+        >>> input_x = Tensor(np.random.randint(0, 256, (3, 3)), mindspore.float32)
+        >>> mean_grad = Tensor(np.random.randint(-8, 8, (3, 3)), mindspore.float32)
+        >>> mean_square = Tensor(np.random.randint(0, 256, (3, 3)), mindspore.float32)
+        >>> moment = Tensor(np.random.randn(3, 3), mindspore.float32)
+        >>> grad = Tensor(np.random.randint(-32, 16, (3, 3)), mindspore.float32)
+        >>> learning_rate = 0.9
+        >>> decay = 0.0
+        >>> momentum = 1e-10
+        >>> epsilon = 0.001
+        >>> result = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
+        ...                            learning_rate, decay, momentum, epsilon)
     """
     @prim_attr_register
@@ -1827,6 +1898,18 @@ class L2Normalize(PrimitiveWithInfer):
     Outputs:
         Tensor, with the same type and shape as the input.
+    Examples:
+        >>> l2_normalize = P.L2Normalize()
+        >>> input_x = Tensor(np.random.randint(-256, 256, (2, 3, 4)), mindspore.float32)
+        >>> result = l2_normalize(input_x)
+        [[[-0.47247353 -0.30934513 -0.4991462   0.8185567 ]
+          [-0.08070751 -0.9961299  -0.5741758   0.09262337]
+          [-0.9916556  -0.3049123   0.5730487  -0.40579924]]
+         [[-0.88134485  0.9509498  -0.86651784  0.57442576]
+          [ 0.99673784  0.08789381 -0.8187321   0.9957012 ]
+          [ 0.12891524 -0.9523804  -0.81952125  0.91396334]]]
     """
     @prim_attr_register
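L2Normalize divides the input by its L2 norm along the reduction axis, with epsilon acting as a floor under the squared norm. A NumPy sketch, assuming the defaults are axis 0 and epsilon 1e-4:

>>> import numpy as np
>>> x = np.random.randn(2, 3, 4).astype(np.float32)
>>> out = x / np.sqrt(np.maximum((x * x).sum(axis=0, keepdims=True), 1e-4))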
@@ -2138,6 +2221,32 @@ class PReLU(PrimitiveWithInfer):
         Tensor, with the same type as `input_x`.
     For detailed information, please refer to `nn.PReLU`.
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.prelu = P.PReLU()
+        >>>     def construct(self, input_x, weight):
+        >>>         result = self.prelu(input_x, weight)
+        >>>         return result
+        >>>
+        >>> input_x = Tensor(np.random.randint(-3, 3, (2, 3, 2)), mindspore.float32)
+        >>> weight = Tensor(np.array([0.1, 0.6, -0.3]), mindspore.float32)
+        >>> net = Net()
+        >>> result = net(input_x, weight)
+        [[[-0.1  1. ]
+          [ 0.   2. ]
+          [ 0.   0. ]]
+         [[-0.2 -0.1      ]
+          [ 2.  -1.8000001]
+          [ 0.6  0.6      ]]]
     """
     @prim_attr_register
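Elementwise, PReLU keeps non-negative inputs and scales negative ones by a learned per-channel slope; with the channel on axis 1 as in the example:

>>> import numpy as np
>>> x = np.random.randint(-3, 3, (2, 3, 2)).astype(np.float32)
>>> w = np.array([0.1, 0.6, -0.3], np.float32)             # one slope per channel
>>> out = np.where(x >= 0, x, w[None, :, None] * x)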
@@ -2547,6 +2656,27 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
     Outputs:
         Tensor or Scalar. If `reduction` is 'none', the output is a tensor with the same shape as `input_x`.
         Otherwise, it is a scalar.
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.binary_cross_entropy = P.BinaryCrossEntropy()
+        >>>     def construct(self, x, y, weight):
+        >>>         result = self.binary_cross_entropy(x, y, weight)
+        >>>         return result
+        >>>
+        >>> net = Net()
+        >>> input_x = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
+        >>> input_y = Tensor(np.array([0., 1., 0.]), mindspore.float32)
+        >>> weight = Tensor(np.array([1, 2, 2]), mindspore.float32)
+        >>> result = net(input_x, input_y, weight)
+        0.38240486
     """
     @prim_attr_register
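The printed scalar can be reproduced directly from the weighted BCE definition with mean reduction:

>>> import numpy as np
>>> x = np.array([0.2, 0.7, 0.1], np.float32)
>>> y = np.array([0., 1., 0.], np.float32)
>>> w = np.array([1, 2, 2], np.float32)
>>> -(w * (y * np.log(x) + (1 - y) * np.log(1 - x))).mean()   # ~0.38240486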
@@ -2726,6 +2856,37 @@ class ApplyFtrl(PrimitiveWithInfer):
     Outputs:
         Tensor, representing the updated var.
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Parameter
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class ApplyFtrlNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(ApplyFtrlNet, self).__init__()
+        >>>         self.apply_ftrl = P.ApplyFtrl()
+        >>>         self.lr = 0.001
+        >>>         self.l1 = 0.0
+        >>>         self.l2 = 0.0
+        >>>         self.lr_power = -0.5
+        >>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        >>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
+        >>>         self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
+        >>>
+        >>>     def construct(self, grad):
+        >>>         out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2,
+        >>>                               self.lr_power)
+        >>>         return out
+        >>>
+        >>> net = ApplyFtrlNet()
+        >>> input_x = Tensor(np.random.randint(-4, 4, (3, 3)), mindspore.float32)
+        >>> result = net(input_x)
+        [[0.67455846 0.14630564 0.160499  ]
+         [0.16329421 0.00415689 0.05202988]
+         [0.18672481 0.17418946 0.36420345]]
     """
     @prim_attr_register
@@ -2780,6 +2941,18 @@ class ConfusionMulGrad(PrimitiveWithInfer):
           the shape of output is :math:`(x_1, x_3, ..., x_R)`.
         - If axis is tuple(int), set as (2, 3), and keep_dims is False,
           the shape of output is :math:`(x_1, x_4, ..., x_R)`.
+    Examples:
+        >>> confusion_mul_grad = P.ConfusionMulGrad()
+        >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
+        >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
+        >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
+        >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
+        output_0:
+        [[ 3.  1.  0.]
+         [-6.  2. -2.]]
+        output_1:
+        -3.0
     """
     @prim_attr_register
@@ -168,6 +168,26 @@ class CheckValid(PrimitiveWithInfer):
     Outputs:
         Tensor, the validated tensor.
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.check_valid = P.CheckValid()
+        >>>     def construct(self, x, y):
+        >>>         valid_result = self.check_valid(x, y)
+        >>>         return valid_result
+        >>>
+        >>> bboxes = Tensor(np.linspace(0, 6, 12).reshape(3, 4), mindspore.float32)
+        >>> img_metas = Tensor(np.array([2, 1, 3]), mindspore.float32)
+        >>> net = Net()
+        >>> result = net(bboxes, img_metas)
+        [True False False]
     """
     @prim_attr_register
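The printed mask is consistent with the usual validity rule, box corners inside the ratio-scaled image: x1 >= 0, y1 >= 0, x2 <= width * ratio - 1 and y2 <= height * ratio - 1. This rule is inferred from the example, so take the sketch as an approximation of the operator's check:

>>> import numpy as np
>>> bboxes = np.linspace(0, 6, 12).reshape(3, 4)
>>> height, width, ratio = 2, 1, 3
>>> x1, y1, x2, y2 = bboxes.T
>>> (x1 >= 0) & (y1 >= 0) & (x2 <= width * ratio - 1) & (y2 <= height * ratio - 1)
array([ True, False, False])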