From 3a1a9ff6dc6dd14c0f7e4e89c736d5c27edcedd9 Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Thu, 5 Nov 2020 11:47:40 +0800 Subject: [PATCH] add LogUniformCandidateSampler and ComputeAccidentalHits op for aicpu --- mindspore/ops/_op_impl/aicpu/__init__.py | 3 + .../_op_impl/aicpu/compute_accidental_hits.py | 44 ++++++++++++ .../aicpu/log_uniform_candidate_sampler.py | 36 ++++++++++ .../aicpu/uniform_candidate_sampler.py | 36 ++++++++++ mindspore/ops/operations/__init__.py | 7 +- mindspore/ops/operations/nn_ops.py | 65 +++++++++++++++++ mindspore/ops/operations/random_ops.py | 70 +++++++++++++++++++ .../test_compute_accidental_hits.py | 47 +++++++++++++ .../test_log_uniform_candidate_sampler.py | 45 ++++++++++++ 9 files changed, 351 insertions(+), 2 deletions(-) create mode 100644 mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py create mode 100644 mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py create mode 100644 mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py create mode 100644 tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py create mode 100644 tests/st/ops/ascend/test_aicpu_ops/test_log_uniform_candidate_sampler.py diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index d337ac7c56..f0383e7a5e 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -35,6 +35,9 @@ from .expand_dims import _expand_dims_aicpu from .randperm import _randperm_aicpu from .random_choice_with_mask import _random_choice_with_mask_aicpu from .pack import _pack_aicpu +from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu +from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu +from .compute_accidental_hits import _compute_accidental_hits_aicpu from .ctcloss import _ctcloss_aicpu from .reverse_sequence import _reverse_sequence_aicpu from .crop_and_resize import _crop_and_resize_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py b/mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py new file mode 100644 index 0000000000..01c76c8c11 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ComputeAccidentalHits op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +compute_accidental_hits_op_info = AiCPURegOp("ComputeAccidentalHits") \ + .fusion_type("OPAQUE") \ + .input(0, "true_classes", "required") \ + .input(1, "sampled_candidates", "required") \ + .output(0, "indices", "required") \ + .output(1, "ids", "required") \ + .output(2, "weights", "required") \ + .attr("num_true", "int") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, + DataType.I64_Default, DataType.F64_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, + DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, + DataType.I64_Default, DataType.F16_Default) \ + .get_op_info() + + +@op_info_register(compute_accidental_hits_op_info) +def _compute_accidental_hits_aicpu(): + """ComputeAccidentalHits AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py b/mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py new file mode 100644 index 0000000000..c5fa9de297 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LogUniformCandidateSampler op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +log_uniform_candidate_sampler_op_info = AiCPURegOp("LogUniformCandidateSampler") \ + .fusion_type("OPAQUE") \ + .input(0, "true_classes", "required") \ + .output(0, "sampled_candidates", "required") \ + .output(1, "true_expected_count", "required") \ + .output(2, "true_expected_count", "required") \ + .attr("num_true", "int") \ + .attr("num_sampled", "int") \ + .attr("unique", "bool") \ + .attr("range_max", "int") \ + .attr("seed", "int") \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(log_uniform_candidate_sampler_op_info) +def _log_uniform_candidate_sampler_aicpu(): + """LogUniformCandidateSampler AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py b/mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py new file mode 100644 index 0000000000..5e4c24f130 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""UniformCandidateSampler op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +uniform_candidate_sampler_op_info = AiCPURegOp("UniformCandidateSampler") \ + .fusion_type("OPAQUE") \ + .input(0, "true_classes", "required") \ + .output(0, "sampled_candidates", "required") \ + .output(1, "true_expected_count", "required") \ + .output(2, "true_expected_count", "required") \ + .attr("num_true", "int") \ + .attr("num_sampled", "int") \ + .attr("unique", "bool") \ + .attr("range_max", "int") \ + .attr("seed", "int") \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(uniform_candidate_sampler_op_info) +def _uniform_candidate_sampler_aicpu(): + """UniformCandidateSampler AiCPU register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 5a45150634..f58cd5d0b5 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -57,7 +57,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) from .random_ops import (Randperm, RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, - RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler) + RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler, + LogUniformCandidateSampler) from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative, @@ -67,7 +68,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder, LogSoftmax, MaxPool, DataFormatDimMap, - AvgPool, Conv2DBackpropInput, + AvgPool, Conv2DBackpropInput, ComputeAccidentalHits, MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits, @@ -342,6 +343,7 @@ __all__ = [ "SpaceToDepth", "DepthToSpace", "Conv2DBackpropInput", + "ComputeAccidentalHits", "Sign", "LARSUpdate", "Round", @@ -384,6 +386,7 @@ __all__ = [ "InplaceUpdate", "InTopK", "UniformCandidateSampler", + "LogUniformCandidateSampler", "LRN", "Mod", "ConfusionMatrix", diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 587499e9fb..a87244b708 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3424,6 +3424,71 @@ class MirrorPad(PrimitiveWithInfer): 'value': None} +class ComputeAccidentalHits(PrimitiveWithInfer): + """ + Compute accidental hits of sampled classes which happen to match target classes. + + When a target class matches the sample class, we call it "accidental hit". + The result of calculating accidental hit contains three parts (index, id, weight), + where index represents the row number in true_classes, and id represents the position in sampled_candidates, + the weight is -FLOAT_MAX. + + Args: + num_true (int): The number of target classes per training example. + + Inputs: + - **true_classes** (Tensor) - The target classes. With data type of int32 or int64 + and shape [batch_size, num_true]. + - **sampled_candidates** (Tensor) - The sampled_candidates output of CandidateSampler, + with shape [num_sampled] and the same type as true_classes. + + Outputs: + Tuple of 3 Tensors. + + - **indices** (Tensor) - A Tensor with shape (num_accidental_hits,), with the same type as `true_classes`. + - **ids** (Tensor) - A Tensor with shape (num_accidental_hits,), with the same type as `true_classes`. + - **weights** (Tensor) - A Tensor with shape (num_accidental_hits,), with the type float32. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> x = np.array([[1, 2], [0, 4], [3, 3]]) + >>> y = np.array([0, 1, 2, 3, 4]) + >>> sampler = ops.ComputeAccidentalHits(2) + >>> output1, output2, output3 = sampler(Tensor(x), Tensor(y)) + >>> print(output1, output2, output3) + [0, 0, 1, 1, 2, 2], [1, 2, 0, 4, 3, 3], + [-3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38] + + """ + + @prim_attr_register + def __init__(self, num_true=1): + """Initialize ComputeAccidentalHits""" + self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'], + outputs=['indices', 'ids', 'weights']) + validator.check_value_type("num_true", num_true, [int], self.name) + self.num_true = num_true + + def infer_shape(self, true_classes_shape, sampled_candidates_shape): + validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name) + validator.check("sampled_candidates shape rank", len(sampled_candidates_shape), "expect", 1, Rel.EQ, self.name) + validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name) + + indices_len = -1 + return (indices_len,), (indices_len,), (indices_len,) + + def infer_dtype(self, true_classes_type, sampled_candidates_type): + validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name) + validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name) + valid_types = (mstype.int32, mstype.int64) + validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, valid_types, self.name) + validator.check_tensor_dtype_valid("sampled_candidates_type", sampled_candidates_type, valid_types, self.name) + weights_type = mstype.float32 + return true_classes_type, true_classes_type, weights_type + + class ROIAlign(PrimitiveWithInfer): """ Computes the Region of Interest (RoI) Align operator. diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 5c7070c082..da0e900b5c 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -634,3 +634,73 @@ class UniformCandidateSampler(PrimitiveWithInfer): def infer_shape(self, true_classes_shape): Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name) return ([self.num_sampled], true_classes_shape, [self.num_sampled]) + + +class LogUniformCandidateSampler(PrimitiveWithInfer): + """ + Generates random labels with a log-uniform distribution for sampled_candidates. + + Random sampling a tensor of sampled classes from the range of integers [0, range_max). + + Args: + num_true (int): The number of target classes per training example. Default: 1. + num_sampled (int): The number of classes to randomly sample. Default: 5. + unique (bool): Determines whether sample with rejection. If unique is True, + all sampled classes in a batch are unique. Default: True. + range_max (int): The number of possible classes. Default: 5. + seed (int): Random seed, must be non-negative. + + Inputs: + - **true_classes** (Tensor) - The target classes. With data type of int64 and shape [batch_size, num_true]. + + Outputs: + Tuple of 3 Tensors. + + - **sampled_candidates** (Tensor) - A Tensor with shape (num_sampled,) and the same type as `true_classes`. + - **true_expected_count** (Tensor) - A Tensor with the same shape as `true_classes and` type float32. + - **sampled_expected_count** (Tensor) - A Tensor with the same shape as `sampled_candidates` and type float32. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> sampler = ops.LogUniformCandidateSampler(1, 5, True, 5) + >>> output1, output2, output3 = sampler(Tensor(np.array([[1, 7], [0, 4], [3, 3]]))) + >>> print(output1, output2, output3) + [3, 2, 0, 4, 1], + [[9.23129916e-01, 4.93363708e-01], + [9.92489874e-01, 6.58063710e-01], + [7.35534430e-01, 7.35534430e-01]], + [7.35534430e-01, 8.26258004e-01, 9.92489874e-01, 6.58063710e-01, 9.23129916e-01] + + """ + + @prim_attr_register + def __init__(self, num_true=1, num_sampled=5, unique=True, range_max=5, seed=0): + """Initialize LogUniformCandidateSampler""" + self.init_prim_io_names(inputs=['true_classes'], + outputs=['sampled_candidates', 'true_expected_count', 'sampled_expected_count']) + Validator.check_value_type("num_true", num_true, [int], self.name) + Validator.check_value_type("num_sampled", num_sampled, [int], self.name) + Validator.check_value_type("unique", unique, [bool], self.name) + Validator.check_value_type("range_max", range_max, [int], self.name) + Validator.check_value_type("seed", seed, [int], self.name) + self.num_true = Validator.check_number("num_true", num_true, 1, Rel.GE, self.name) + self.num_sampled = Validator.check_number("num_sampled", num_sampled, 1, Rel.GE, self.name) + if unique: + Validator.check_number("range_max", range_max, num_sampled, Rel.GE, self.name) + self.range_max = range_max + self.unique = unique + self.seed = seed + + def infer_shape(self, true_classes_shape): + Validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name) + Validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name) + return (self.num_sampled,), true_classes_shape, (self.num_sampled,) + + def infer_dtype(self, true_classes_type): + Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name) + valid_types = (mstype.int64,) + Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, valid_types, self.name) + expected_type = mstype.float32 + return true_classes_type, expected_type, expected_type diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py b/tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py new file mode 100644 index 0000000000..74fbcaf630 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, num_true=1): + super(Net, self).__init__() + self.sampler = P.ComputeAccidentalHits(num_true) + + def construct(self, x, y): + return self.sampler(x, y) + + +def test_net(): + x = np.array([[1, 2], [0, 4], [3, 3]]) + y = np.array([0, 1, 2, 3, 4]) + net = Net(2) + output1, output2, output3 = net(Tensor(x), Tensor(y)) + print(output1, output2, output3) + + output1_expect = np.array([0, 0, 1, 1, 2, 2]) + output2_expect = np.array([1, 2, 0, 4, 3, 3]) + output3_expect = np.array([-3.4028235+38, -3.4028235+38, -3.4028235+38, + -3.4028235+38, -3.4028235+38, -3.4028235+38]).astype(np.float32) + assert np.array_equal(output1.asnumpy(), output1_expect) + assert np.array_equal(output2.asnumpy(), output2_expect) + assert np.array_equal(output3.asnumpy(), output3_expect) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_log_uniform_candidate_sampler.py b/tests/st/ops/ascend/test_aicpu_ops/test_log_uniform_candidate_sampler.py new file mode 100644 index 0000000000..913398fa16 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_log_uniform_candidate_sampler.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, num_true=1, num_sampled=5, unique=True, range_max=5, seed=0): + super(Net, self).__init__() + self.sampler = P.LogUniformCandidateSampler(num_true, num_sampled, unique, range_max, seed) + + def construct(self, x): + return self.sampler(x) + + +def test_net_true(): + x = np.array([[1, 7], [0, 4], [3, 3]]) + net = Net(2, 5, True, 5) + output = net(Tensor(x)) + print(output) + + +def test_net_false(): + x = np.array([[1, 7], [0, 4], [3, 3]]) + net = Net(2, 5, False, 10) + output = net(Tensor(x)) + print(output)