| @@ -35,6 +35,9 @@ from .expand_dims import _expand_dims_aicpu | |||||
| from .randperm import _randperm_aicpu | from .randperm import _randperm_aicpu | ||||
| from .random_choice_with_mask import _random_choice_with_mask_aicpu | from .random_choice_with_mask import _random_choice_with_mask_aicpu | ||||
| from .pack import _pack_aicpu | from .pack import _pack_aicpu | ||||
| from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu | |||||
| from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu | |||||
| from .compute_accidental_hits import _compute_accidental_hits_aicpu | |||||
| from .ctcloss import _ctcloss_aicpu | from .ctcloss import _ctcloss_aicpu | ||||
| from .reverse_sequence import _reverse_sequence_aicpu | from .reverse_sequence import _reverse_sequence_aicpu | ||||
| from .crop_and_resize import _crop_and_resize_aicpu | from .crop_and_resize import _crop_and_resize_aicpu | ||||
| @@ -0,0 +1,44 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ComputeAccidentalHits op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| compute_accidental_hits_op_info = AiCPURegOp("ComputeAccidentalHits") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "true_classes", "required") \ | |||||
| .input(1, "sampled_candidates", "required") \ | |||||
| .output(0, "indices", "required") \ | |||||
| .output(1, "ids", "required") \ | |||||
| .output(2, "weights", "required") \ | |||||
| .attr("num_true", "int") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, | |||||
| DataType.I32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, | |||||
| DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, | |||||
| DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, | |||||
| DataType.I64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, | |||||
| DataType.I64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, | |||||
| DataType.I64_Default, DataType.F16_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(compute_accidental_hits_op_info) | |||||
| def _compute_accidental_hits_aicpu(): | |||||
| """ComputeAccidentalHits AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """LogUniformCandidateSampler op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| log_uniform_candidate_sampler_op_info = AiCPURegOp("LogUniformCandidateSampler") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "true_classes", "required") \ | |||||
| .output(0, "sampled_candidates", "required") \ | |||||
| .output(1, "true_expected_count", "required") \ | |||||
| .output(2, "true_expected_count", "required") \ | |||||
| .attr("num_true", "int") \ | |||||
| .attr("num_sampled", "int") \ | |||||
| .attr("unique", "bool") \ | |||||
| .attr("range_max", "int") \ | |||||
| .attr("seed", "int") \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(log_uniform_candidate_sampler_op_info) | |||||
| def _log_uniform_candidate_sampler_aicpu(): | |||||
| """LogUniformCandidateSampler AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """UniformCandidateSampler op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| uniform_candidate_sampler_op_info = AiCPURegOp("UniformCandidateSampler") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "true_classes", "required") \ | |||||
| .output(0, "sampled_candidates", "required") \ | |||||
| .output(1, "true_expected_count", "required") \ | |||||
| .output(2, "true_expected_count", "required") \ | |||||
| .attr("num_true", "int") \ | |||||
| .attr("num_sampled", "int") \ | |||||
| .attr("unique", "bool") \ | |||||
| .attr("range_max", "int") \ | |||||
| .attr("seed", "int") \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(uniform_candidate_sampler_op_info) | |||||
| def _uniform_candidate_sampler_aicpu(): | |||||
| """UniformCandidateSampler AiCPU register""" | |||||
| return | |||||
| @@ -57,7 +57,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | ||||
| from .random_ops import (Randperm, RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | from .random_ops import (Randperm, RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | ||||
| RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler) | |||||
| RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler, | |||||
| LogUniformCandidateSampler) | |||||
| from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm, | from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm, | ||||
| BiasAdd, Conv2D, | BiasAdd, Conv2D, | ||||
| DepthwiseConv2dNative, | DepthwiseConv2dNative, | ||||
| @@ -67,7 +68,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam | |||||
| GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder, | GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder, | ||||
| LogSoftmax, | LogSoftmax, | ||||
| MaxPool, DataFormatDimMap, | MaxPool, DataFormatDimMap, | ||||
| AvgPool, Conv2DBackpropInput, | |||||
| AvgPool, Conv2DBackpropInput, ComputeAccidentalHits, | |||||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | ||||
| ResizeBilinear, Sigmoid, | ResizeBilinear, Sigmoid, | ||||
| SigmoidCrossEntropyWithLogits, | SigmoidCrossEntropyWithLogits, | ||||
| @@ -342,6 +343,7 @@ __all__ = [ | |||||
| "SpaceToDepth", | "SpaceToDepth", | ||||
| "DepthToSpace", | "DepthToSpace", | ||||
| "Conv2DBackpropInput", | "Conv2DBackpropInput", | ||||
| "ComputeAccidentalHits", | |||||
| "Sign", | "Sign", | ||||
| "LARSUpdate", | "LARSUpdate", | ||||
| "Round", | "Round", | ||||
| @@ -384,6 +386,7 @@ __all__ = [ | |||||
| "InplaceUpdate", | "InplaceUpdate", | ||||
| "InTopK", | "InTopK", | ||||
| "UniformCandidateSampler", | "UniformCandidateSampler", | ||||
| "LogUniformCandidateSampler", | |||||
| "LRN", | "LRN", | ||||
| "Mod", | "Mod", | ||||
| "ConfusionMatrix", | "ConfusionMatrix", | ||||
| @@ -3424,6 +3424,71 @@ class MirrorPad(PrimitiveWithInfer): | |||||
| 'value': None} | 'value': None} | ||||
| class ComputeAccidentalHits(PrimitiveWithInfer): | |||||
| """ | |||||
| Compute accidental hits of sampled classes which happen to match target classes. | |||||
| When a target class matches the sample class, we call it "accidental hit". | |||||
| The result of calculating accidental hit contains three parts (index, id, weight), | |||||
| where index represents the row number in true_classes, and id represents the position in sampled_candidates, | |||||
| the weight is -FLOAT_MAX. | |||||
| Args: | |||||
| num_true (int): The number of target classes per training example. | |||||
| Inputs: | |||||
| - **true_classes** (Tensor) - The target classes. With data type of int32 or int64 | |||||
| and shape [batch_size, num_true]. | |||||
| - **sampled_candidates** (Tensor) - The sampled_candidates output of CandidateSampler, | |||||
| with shape [num_sampled] and the same type as true_classes. | |||||
| Outputs: | |||||
| Tuple of 3 Tensors. | |||||
| - **indices** (Tensor) - A Tensor with shape (num_accidental_hits,), with the same type as `true_classes`. | |||||
| - **ids** (Tensor) - A Tensor with shape (num_accidental_hits,), with the same type as `true_classes`. | |||||
| - **weights** (Tensor) - A Tensor with shape (num_accidental_hits,), with the type float32. | |||||
| Supported Platforms: | |||||
| ``Ascend`` | |||||
| Examples: | |||||
| >>> x = np.array([[1, 2], [0, 4], [3, 3]]) | |||||
| >>> y = np.array([0, 1, 2, 3, 4]) | |||||
| >>> sampler = ops.ComputeAccidentalHits(2) | |||||
| >>> output1, output2, output3 = sampler(Tensor(x), Tensor(y)) | |||||
| >>> print(output1, output2, output3) | |||||
| [0, 0, 1, 1, 2, 2], [1, 2, 0, 4, 3, 3], | |||||
| [-3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, num_true=1): | |||||
| """Initialize ComputeAccidentalHits""" | |||||
| self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'], | |||||
| outputs=['indices', 'ids', 'weights']) | |||||
| validator.check_value_type("num_true", num_true, [int], self.name) | |||||
| self.num_true = num_true | |||||
| def infer_shape(self, true_classes_shape, sampled_candidates_shape): | |||||
| validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name) | |||||
| validator.check("sampled_candidates shape rank", len(sampled_candidates_shape), "expect", 1, Rel.EQ, self.name) | |||||
| validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name) | |||||
| indices_len = -1 | |||||
| return (indices_len,), (indices_len,), (indices_len,) | |||||
| def infer_dtype(self, true_classes_type, sampled_candidates_type): | |||||
| validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name) | |||||
| validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name) | |||||
| valid_types = (mstype.int32, mstype.int64) | |||||
| validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, valid_types, self.name) | |||||
| validator.check_tensor_dtype_valid("sampled_candidates_type", sampled_candidates_type, valid_types, self.name) | |||||
| weights_type = mstype.float32 | |||||
| return true_classes_type, true_classes_type, weights_type | |||||
| class ROIAlign(PrimitiveWithInfer): | class ROIAlign(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Computes the Region of Interest (RoI) Align operator. | Computes the Region of Interest (RoI) Align operator. | ||||
| @@ -634,3 +634,73 @@ class UniformCandidateSampler(PrimitiveWithInfer): | |||||
| def infer_shape(self, true_classes_shape): | def infer_shape(self, true_classes_shape): | ||||
| Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name) | Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name) | ||||
| return ([self.num_sampled], true_classes_shape, [self.num_sampled]) | return ([self.num_sampled], true_classes_shape, [self.num_sampled]) | ||||
| class LogUniformCandidateSampler(PrimitiveWithInfer): | |||||
| """ | |||||
| Generates random labels with a log-uniform distribution for sampled_candidates. | |||||
| Random sampling a tensor of sampled classes from the range of integers [0, range_max). | |||||
| Args: | |||||
| num_true (int): The number of target classes per training example. Default: 1. | |||||
| num_sampled (int): The number of classes to randomly sample. Default: 5. | |||||
| unique (bool): Determines whether sample with rejection. If unique is True, | |||||
| all sampled classes in a batch are unique. Default: True. | |||||
| range_max (int): The number of possible classes. Default: 5. | |||||
| seed (int): Random seed, must be non-negative. | |||||
| Inputs: | |||||
| - **true_classes** (Tensor) - The target classes. With data type of int64 and shape [batch_size, num_true]. | |||||
| Outputs: | |||||
| Tuple of 3 Tensors. | |||||
| - **sampled_candidates** (Tensor) - A Tensor with shape (num_sampled,) and the same type as `true_classes`. | |||||
| - **true_expected_count** (Tensor) - A Tensor with the same shape as `true_classes and` type float32. | |||||
| - **sampled_expected_count** (Tensor) - A Tensor with the same shape as `sampled_candidates` and type float32. | |||||
| Supported Platforms: | |||||
| ``Ascend`` | |||||
| Examples: | |||||
| >>> sampler = ops.LogUniformCandidateSampler(1, 5, True, 5) | |||||
| >>> output1, output2, output3 = sampler(Tensor(np.array([[1, 7], [0, 4], [3, 3]]))) | |||||
| >>> print(output1, output2, output3) | |||||
| [3, 2, 0, 4, 1], | |||||
| [[9.23129916e-01, 4.93363708e-01], | |||||
| [9.92489874e-01, 6.58063710e-01], | |||||
| [7.35534430e-01, 7.35534430e-01]], | |||||
| [7.35534430e-01, 8.26258004e-01, 9.92489874e-01, 6.58063710e-01, 9.23129916e-01] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, num_true=1, num_sampled=5, unique=True, range_max=5, seed=0): | |||||
| """Initialize LogUniformCandidateSampler""" | |||||
| self.init_prim_io_names(inputs=['true_classes'], | |||||
| outputs=['sampled_candidates', 'true_expected_count', 'sampled_expected_count']) | |||||
| Validator.check_value_type("num_true", num_true, [int], self.name) | |||||
| Validator.check_value_type("num_sampled", num_sampled, [int], self.name) | |||||
| Validator.check_value_type("unique", unique, [bool], self.name) | |||||
| Validator.check_value_type("range_max", range_max, [int], self.name) | |||||
| Validator.check_value_type("seed", seed, [int], self.name) | |||||
| self.num_true = Validator.check_number("num_true", num_true, 1, Rel.GE, self.name) | |||||
| self.num_sampled = Validator.check_number("num_sampled", num_sampled, 1, Rel.GE, self.name) | |||||
| if unique: | |||||
| Validator.check_number("range_max", range_max, num_sampled, Rel.GE, self.name) | |||||
| self.range_max = range_max | |||||
| self.unique = unique | |||||
| self.seed = seed | |||||
| def infer_shape(self, true_classes_shape): | |||||
| Validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name) | |||||
| Validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name) | |||||
| return (self.num_sampled,), true_classes_shape, (self.num_sampled,) | |||||
| def infer_dtype(self, true_classes_type): | |||||
| Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name) | |||||
| valid_types = (mstype.int64,) | |||||
| Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, valid_types, self.name) | |||||
| expected_type = mstype.float32 | |||||
| return true_classes_type, expected_type, expected_type | |||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, num_true=1): | |||||
| super(Net, self).__init__() | |||||
| self.sampler = P.ComputeAccidentalHits(num_true) | |||||
| def construct(self, x, y): | |||||
| return self.sampler(x, y) | |||||
| def test_net(): | |||||
| x = np.array([[1, 2], [0, 4], [3, 3]]) | |||||
| y = np.array([0, 1, 2, 3, 4]) | |||||
| net = Net(2) | |||||
| output1, output2, output3 = net(Tensor(x), Tensor(y)) | |||||
| print(output1, output2, output3) | |||||
| output1_expect = np.array([0, 0, 1, 1, 2, 2]) | |||||
| output2_expect = np.array([1, 2, 0, 4, 3, 3]) | |||||
| output3_expect = np.array([-3.4028235+38, -3.4028235+38, -3.4028235+38, | |||||
| -3.4028235+38, -3.4028235+38, -3.4028235+38]).astype(np.float32) | |||||
| assert np.array_equal(output1.asnumpy(), output1_expect) | |||||
| assert np.array_equal(output2.asnumpy(), output2_expect) | |||||
| assert np.array_equal(output3.asnumpy(), output3_expect) | |||||
| @@ -0,0 +1,45 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, num_true=1, num_sampled=5, unique=True, range_max=5, seed=0): | |||||
| super(Net, self).__init__() | |||||
| self.sampler = P.LogUniformCandidateSampler(num_true, num_sampled, unique, range_max, seed) | |||||
| def construct(self, x): | |||||
| return self.sampler(x) | |||||
| def test_net_true(): | |||||
| x = np.array([[1, 7], [0, 4], [3, 3]]) | |||||
| net = Net(2, 5, True, 5) | |||||
| output = net(Tensor(x)) | |||||
| print(output) | |||||
| def test_net_false(): | |||||
| x = np.array([[1, 7], [0, 4], [3, 3]]) | |||||
| net = Net(2, 5, False, 10) | |||||
| output = net(Tensor(x)) | |||||
| print(output) | |||||