| @@ -47,7 +47,7 @@ constexpr auto kEditDistance = "EditDistance"; | |||||
| constexpr auto kGatherD = "GatherD"; | constexpr auto kGatherD = "GatherD"; | ||||
| constexpr auto kIdentity = "Identity"; | constexpr auto kIdentity = "Identity"; | ||||
| constexpr auto kCustRunApi = "RunCpuKernel"; | constexpr auto kCustRunApi = "RunCpuKernel"; | ||||
| const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kGatherD, kIdentity, kMeshgrid}; | |||||
| const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity}; | |||||
| struct AicpuParamHead { | struct AicpuParamHead { | ||||
| uint32_t length; // Total length: include cunstom message | uint32_t length; // Total length: include cunstom message | ||||
| @@ -18,6 +18,8 @@ from .init_data_set_queue import _init_data_set_queue_aicpu | |||||
| from .embedding_lookup import _embedding_lookup_aicpu | from .embedding_lookup import _embedding_lookup_aicpu | ||||
| from .padding import _padding_aicpu | from .padding import _padding_aicpu | ||||
| from .gather import _gather_aicpu | from .gather import _gather_aicpu | ||||
| from .gather_grad import _gather_grad_aicpu | |||||
| from .scatter import _scatter_aicpu | |||||
| from .identity import _identity_aicpu | from .identity import _identity_aicpu | ||||
| from .edit_distance import _edit_distance_aicpu | from .edit_distance import _edit_distance_aicpu | ||||
| from .unique_with_pad import _unique_with_pad_aicpu | from .unique_with_pad import _unique_with_pad_aicpu | ||||
| @@ -30,6 +32,7 @@ from .reshape import _reshape_aicpu | |||||
| from .flatten import _flatten_aicpu | from .flatten import _flatten_aicpu | ||||
| from .squeeze import _squeeze_aicpu | from .squeeze import _squeeze_aicpu | ||||
| from .expand_dims import _expand_dims_aicpu | from .expand_dims import _expand_dims_aicpu | ||||
| from .randperm import _randperm_aicpu | |||||
| from .random_choice_with_mask import _random_choice_with_mask_aicpu | from .random_choice_with_mask import _random_choice_with_mask_aicpu | ||||
| from .pack import _pack_aicpu | from .pack import _pack_aicpu | ||||
| from .ctcloss import _ctcloss_aicpu | from .ctcloss import _ctcloss_aicpu | ||||
| @@ -0,0 +1,54 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """GatherDGrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| gather_grad_op_info = AiCPURegOp("GatherDGrad") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .attr("dim", "int") \ | |||||
| .input(0, "index", "required") \ | |||||
| .input(1, "src", "required") \ | |||||
| .output(0, "output", "required") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(gather_grad_op_info) | |||||
| def _gather_grad_aicpu(): | |||||
| """GatherDGrad AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Randperm op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| randperm_op_info = AiCPURegOp("Randperm") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .output(0, "y", "required") \ | |||||
| .attr("n", "int") \ | |||||
| .dtype_format(DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(randperm_op_info) | |||||
| def _randperm_aicpu(): | |||||
| """Randperm AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,79 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Scatter op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| scatter_op_info = AiCPURegOp("Scatter") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "target", "required") \ | |||||
| .input(1, "dim", "required") \ | |||||
| .input(2, "index", "required") \ | |||||
| .input(3, "src", "required") \ | |||||
| .output(0, "output", "required") \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(scatter_op_info) | |||||
| def _scatter_aicpu(): | |||||
| """Scatter AiCPU register""" | |||||
| return | |||||
| @@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | ||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | ||||
| from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | |||||
| from .random_ops import (Randperm, RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | |||||
| RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler) | RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler) | ||||
| from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm, | from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm, | ||||
| BiasAdd, Conv2D, | BiasAdd, Conv2D, | ||||
| @@ -197,6 +197,7 @@ __all__ = [ | |||||
| 'HSwish', | 'HSwish', | ||||
| 'HSigmoid', | 'HSigmoid', | ||||
| 'Tanh', | 'Tanh', | ||||
| 'Randperm', | |||||
| 'RandomChoiceWithMask', | 'RandomChoiceWithMask', | ||||
| 'StandardNormal', | 'StandardNormal', | ||||
| 'Multinomial', | 'Multinomial', | ||||
| @@ -4631,12 +4631,8 @@ class GatherD(PrimitiveWithInfer): | |||||
| idx_rank = len(idx_shp) | idx_rank = len(idx_shp) | ||||
| validator.check("x_rank, idx_rank", x_rank, "expected", idx_rank, Rel.EQ, self.name) | validator.check("x_rank, idx_rank", x_rank, "expected", idx_rank, Rel.EQ, self.name) | ||||
| dim_v = dim['value'] | dim_v = dim['value'] | ||||
| validator.check("dim value", dim_v, "expected", 0, Rel.GE, self.name) | |||||
| validator.check("dim value", dim_v, "expected", -x_rank, Rel.GE, self.name) | |||||
| validator.check("dim value", dim_v, "expected", x_rank, Rel.LT, self.name) | validator.check("dim value", dim_v, "expected", x_rank, Rel.LT, self.name) | ||||
| for i in range(x_rank): | |||||
| if i == dim_v: | |||||
| continue | |||||
| validator.check("x_shp[{0}], idx_shp[{0}]".format(i), x_shp[i], "expected", idx_shp[i], Rel.EQ, self.name) | |||||
| out = {'shape': index['shape'], | out = {'shape': index['shape'], | ||||
| 'dtype': x['dtype'], | 'dtype': x['dtype'], | ||||
| @@ -346,6 +346,46 @@ class UniformReal(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| class Randperm(PrimitiveWithInfer): | |||||
| """ | |||||
| Generates random samples from 0 to n-1. | |||||
| Args: | |||||
| n (int): Number of items expected to get and the number must be greater than 0. Default: 1. | |||||
| dtype (mindspore.dtype): The type of output. Default: mindspore.int32. | |||||
| Outputs: | |||||
| - **output** (Tensor) - The output Tensor with shape :math:`(n,)` and type: dtype. | |||||
| Supported Platforms: | |||||
| ``Ascend`` | |||||
| Examples: | |||||
| >>> randperm = ops.Randperm(20) | |||||
| >>> output = randperm() | |||||
| >>> print(output) | |||||
| [15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 14 1 12 3 7] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, n=1, dtype=mstype.int32): | |||||
| """Initialize Randperm""" | |||||
| Validator.check_value_type("n", n, [int], self.name) | |||||
| self.dtype = dtype | |||||
| self.n = n | |||||
| self.init_prim_io_names(inputs=[], outputs=['output']) | |||||
| def infer_shape(self): | |||||
| Validator.check_int(self.n, 1, Rel.GE, "1", self.name) | |||||
| return [self.n] | |||||
| def infer_dtype(self): | |||||
| valid_values = (mstype.int8, mstype.int16, mstype.int32, mstype.int64, | |||||
| mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64) | |||||
| Validator.check_type_name("dtype", self.dtype, valid_values, self.name) | |||||
| return self.dtype | |||||
| class RandomChoiceWithMask(PrimitiveWithInfer): | class RandomChoiceWithMask(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Generates a random sample as index tensor with a mask tensor from a given tensor. | Generates a random sample as index tensor with a mask tensor from a given tensor. | ||||
| @@ -0,0 +1,99 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, dim=0): | |||||
| super(Net, self).__init__() | |||||
| self.op = P.GatherD() | |||||
| self.dim = dim | |||||
| def construct(self, x, index): | |||||
| return self.op(x, self.dim, index) | |||||
| class NetGrad(nn.Cell): | |||||
| def __init__(self, dim=0, shape=None): | |||||
| super(NetGrad, self).__init__() | |||||
| self.op = G.GatherDGrad(dim, shape) | |||||
| def construct(self, index, x): | |||||
| return self.op(index, x) | |||||
| def test_net(): | |||||
| x = Tensor(np.array([[772, 231, 508, 545, 615, 249], | |||||
| [923, 210, 480, 696, 482, 761], | |||||
| [465, 904, 521, 824, 607, 669], | |||||
| [156, 539, 56, 159, 916, 566], | |||||
| [122, 676, 714, 261, 19, 936]]), mindspore.int32) | |||||
| index = Tensor(np.array([[0, 0, 0, 1, 1], | |||||
| [0, 0, 0, 1, 4], | |||||
| [0, 0, 0, 1, -1], | |||||
| [1, 1, 1, 0, 0]]), mindspore.int32) | |||||
| dim = 0 | |||||
| net = Net(dim) | |||||
| out = net(x, index) | |||||
| print(out.asnumpy()) | |||||
| expect_out = np.array([[772, 231, 508, 696, 482], | |||||
| [772, 231, 508, 696, 19], | |||||
| [772, 231, 508, 696, 19], | |||||
| [923, 210, 480, 545, 615]]) | |||||
| assert np.array_equal(out.asnumpy(), expect_out) | |||||
| def test_net_bool(): | |||||
| x = Tensor(np.array([[0, 1, 0, 0, 1, 0], | |||||
| [0, 1, 0, 0, 1, 0], | |||||
| [0, 0, 1, 1, 0, 1], | |||||
| [1, 0, 1, 1, 0, 0], | |||||
| [1, 1, 1, 1, 0, 0]]), mindspore.bool_) | |||||
| index = Tensor(np.array([[0, 0, 0, 1, 1], | |||||
| [0, 0, 0, 1, 4], | |||||
| [0, 0, 0, 1, -1], | |||||
| [1, 1, 1, 0, 0]]), mindspore.int32) | |||||
| dim = 0 | |||||
| net = Net(dim) | |||||
| out = net(x, index) | |||||
| print(out.asnumpy()) | |||||
| expect_out = np.array([[0, 1, 0, 0, 1], | |||||
| [0, 1, 0, 0, 0], | |||||
| [0, 1, 0, 0, 0], | |||||
| [0, 1, 0, 0, 1]]).astype(np.bool) | |||||
| assert np.array_equal(out.asnumpy(), expect_out) | |||||
| def test_net_grad(): | |||||
| index = Tensor(np.array([[0, 1, 2, 0, 0], | |||||
| [2, 0, 0, 1, -1]]), mindspore.int32) | |||||
| x = Tensor(np.array([[772, 231, 508, 615, 249], | |||||
| [122, 676, 714, 261, 936]]), mindspore.int32) | |||||
| net = NetGrad(dim=0, shape=(3, 5)) | |||||
| out = net(index, x) | |||||
| print(out.asnumpy()) | |||||
| expect_out = np.array([[772, 676, 714, 615, 249], | |||||
| [0, 231, 0, 261, 0], | |||||
| [122, 0, 508, 0, 936]]) | |||||
| assert np.array_equal(out.asnumpy(), expect_out) | |||||
| @@ -0,0 +1,56 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import mindspore | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, n=1, dtype=mindspore.int32): | |||||
| super(Net, self).__init__() | |||||
| self.randperm = P.Randperm(n, dtype) | |||||
| def construct(self): | |||||
| return self.randperm() | |||||
| def test_net(): | |||||
| net = Net() | |||||
| output = net() | |||||
| print(output) | |||||
| print(output.shape) | |||||
| print(output.dtype) | |||||
| assert output.shape == (1,) | |||||
| assert output.dtype == mindspore.int32 | |||||
| assert output.asnumpy()[0] == 0 | |||||
| def test_net_n20(): | |||||
| net = Net(20, mindspore.uint64) | |||||
| output = net() | |||||
| print(output) | |||||
| assert output.shape == (20,) | |||||
| assert output.dtype == mindspore.uint64 | |||||
| sample_set = set() | |||||
| for i in output.asnumpy(): | |||||
| assert i not in sample_set | |||||
| assert 0 <= i < 20 | |||||
| sample_set.add(i) | |||||