Browse Source

add LogUniformCandidateSampler and ComputeAccidentalHits op for aicpu

tags/v1.1.0
yanzhenxiang2020 5 years ago
parent
commit
3a1a9ff6dc
9 changed files with 351 additions and 2 deletions
  1. +3
    -0
      mindspore/ops/_op_impl/aicpu/__init__.py
  2. +44
    -0
      mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py
  3. +36
    -0
      mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py
  4. +36
    -0
      mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py
  5. +5
    -2
      mindspore/ops/operations/__init__.py
  6. +65
    -0
      mindspore/ops/operations/nn_ops.py
  7. +70
    -0
      mindspore/ops/operations/random_ops.py
  8. +47
    -0
      tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py
  9. +45
    -0
      tests/st/ops/ascend/test_aicpu_ops/test_log_uniform_candidate_sampler.py

+ 3
- 0
mindspore/ops/_op_impl/aicpu/__init__.py View File

@@ -35,6 +35,9 @@ from .expand_dims import _expand_dims_aicpu
from .randperm import _randperm_aicpu
from .random_choice_with_mask import _random_choice_with_mask_aicpu
from .pack import _pack_aicpu
from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu
from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu
from .compute_accidental_hits import _compute_accidental_hits_aicpu
from .ctcloss import _ctcloss_aicpu
from .reverse_sequence import _reverse_sequence_aicpu
from .crop_and_resize import _crop_and_resize_aicpu


+ 44
- 0
mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py View File

@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ComputeAccidentalHits op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
compute_accidental_hits_op_info = AiCPURegOp("ComputeAccidentalHits") \
.fusion_type("OPAQUE") \
.input(0, "true_classes", "required") \
.input(1, "sampled_candidates", "required") \
.output(0, "indices", "required") \
.output(1, "ids", "required") \
.output(2, "weights", "required") \
.attr("num_true", "int") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
DataType.I64_Default, DataType.F64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
DataType.I64_Default, DataType.F16_Default) \
.get_op_info()


@op_info_register(compute_accidental_hits_op_info)
def _compute_accidental_hits_aicpu():
"""ComputeAccidentalHits AiCPU register"""
return

+ 36
- 0
mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py View File

@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LogUniformCandidateSampler op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
log_uniform_candidate_sampler_op_info = AiCPURegOp("LogUniformCandidateSampler") \
.fusion_type("OPAQUE") \
.input(0, "true_classes", "required") \
.output(0, "sampled_candidates", "required") \
.output(1, "true_expected_count", "required") \
.output(2, "true_expected_count", "required") \
.attr("num_true", "int") \
.attr("num_sampled", "int") \
.attr("unique", "bool") \
.attr("range_max", "int") \
.attr("seed", "int") \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()


@op_info_register(log_uniform_candidate_sampler_op_info)
def _log_uniform_candidate_sampler_aicpu():
"""LogUniformCandidateSampler AiCPU register"""
return

+ 36
- 0
mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py View File

@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""UniformCandidateSampler op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
uniform_candidate_sampler_op_info = AiCPURegOp("UniformCandidateSampler") \
.fusion_type("OPAQUE") \
.input(0, "true_classes", "required") \
.output(0, "sampled_candidates", "required") \
.output(1, "true_expected_count", "required") \
.output(2, "true_expected_count", "required") \
.attr("num_true", "int") \
.attr("num_sampled", "int") \
.attr("unique", "bool") \
.attr("range_max", "int") \
.attr("seed", "int") \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()


@op_info_register(uniform_candidate_sampler_op_info)
def _uniform_candidate_sampler_aicpu():
"""UniformCandidateSampler AiCPU register"""
return

+ 5
- 2
mindspore/ops/operations/__init__.py View File

@@ -57,7 +57,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)

from .random_ops import (Randperm, RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler)
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
LogUniformCandidateSampler)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
@@ -67,7 +68,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
LogSoftmax,
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput,
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid,
SigmoidCrossEntropyWithLogits,
@@ -342,6 +343,7 @@ __all__ = [
"SpaceToDepth",
"DepthToSpace",
"Conv2DBackpropInput",
"ComputeAccidentalHits",
"Sign",
"LARSUpdate",
"Round",
@@ -384,6 +386,7 @@ __all__ = [
"InplaceUpdate",
"InTopK",
"UniformCandidateSampler",
"LogUniformCandidateSampler",
"LRN",
"Mod",
"ConfusionMatrix",


+ 65
- 0
mindspore/ops/operations/nn_ops.py View File

@@ -3424,6 +3424,71 @@ class MirrorPad(PrimitiveWithInfer):
'value': None}


class ComputeAccidentalHits(PrimitiveWithInfer):
"""
Compute accidental hits of sampled classes which happen to match target classes.

When a target class matches the sample class, we call it "accidental hit".
The result of calculating accidental hit contains three parts (index, id, weight),
where index represents the row number in true_classes, and id represents the position in sampled_candidates,
the weight is -FLOAT_MAX.

Args:
num_true (int): The number of target classes per training example.

Inputs:
- **true_classes** (Tensor) - The target classes. With data type of int32 or int64
and shape [batch_size, num_true].
- **sampled_candidates** (Tensor) - The sampled_candidates output of CandidateSampler,
with shape [num_sampled] and the same type as true_classes.

Outputs:
Tuple of 3 Tensors.

- **indices** (Tensor) - A Tensor with shape (num_accidental_hits,), with the same type as `true_classes`.
- **ids** (Tensor) - A Tensor with shape (num_accidental_hits,), with the same type as `true_classes`.
- **weights** (Tensor) - A Tensor with shape (num_accidental_hits,), with the type float32.

Supported Platforms:
``Ascend``

Examples:
>>> x = np.array([[1, 2], [0, 4], [3, 3]])
>>> y = np.array([0, 1, 2, 3, 4])
>>> sampler = ops.ComputeAccidentalHits(2)
>>> output1, output2, output3 = sampler(Tensor(x), Tensor(y))
>>> print(output1, output2, output3)
[0, 0, 1, 1, 2, 2], [1, 2, 0, 4, 3, 3],
[-3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38]

"""

@prim_attr_register
def __init__(self, num_true=1):
"""Initialize ComputeAccidentalHits"""
self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'],
outputs=['indices', 'ids', 'weights'])
validator.check_value_type("num_true", num_true, [int], self.name)
self.num_true = num_true

def infer_shape(self, true_classes_shape, sampled_candidates_shape):
validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name)
validator.check("sampled_candidates shape rank", len(sampled_candidates_shape), "expect", 1, Rel.EQ, self.name)
validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name)

indices_len = -1
return (indices_len,), (indices_len,), (indices_len,)

def infer_dtype(self, true_classes_type, sampled_candidates_type):
validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name)
valid_types = (mstype.int32, mstype.int64)
validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, valid_types, self.name)
validator.check_tensor_dtype_valid("sampled_candidates_type", sampled_candidates_type, valid_types, self.name)
weights_type = mstype.float32
return true_classes_type, true_classes_type, weights_type


class ROIAlign(PrimitiveWithInfer):
"""
Computes the Region of Interest (RoI) Align operator.


+ 70
- 0
mindspore/ops/operations/random_ops.py View File

@@ -634,3 +634,73 @@ class UniformCandidateSampler(PrimitiveWithInfer):
def infer_shape(self, true_classes_shape):
Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
return ([self.num_sampled], true_classes_shape, [self.num_sampled])


class LogUniformCandidateSampler(PrimitiveWithInfer):
"""
Generates random labels with a log-uniform distribution for sampled_candidates.

Random sampling a tensor of sampled classes from the range of integers [0, range_max).

Args:
num_true (int): The number of target classes per training example. Default: 1.
num_sampled (int): The number of classes to randomly sample. Default: 5.
unique (bool): Determines whether sample with rejection. If unique is True,
all sampled classes in a batch are unique. Default: True.
range_max (int): The number of possible classes. Default: 5.
seed (int): Random seed, must be non-negative.

Inputs:
- **true_classes** (Tensor) - The target classes. With data type of int64 and shape [batch_size, num_true].

Outputs:
Tuple of 3 Tensors.

- **sampled_candidates** (Tensor) - A Tensor with shape (num_sampled,) and the same type as `true_classes`.
- **true_expected_count** (Tensor) - A Tensor with the same shape as `true_classes and` type float32.
- **sampled_expected_count** (Tensor) - A Tensor with the same shape as `sampled_candidates` and type float32.

Supported Platforms:
``Ascend``

Examples:
>>> sampler = ops.LogUniformCandidateSampler(1, 5, True, 5)
>>> output1, output2, output3 = sampler(Tensor(np.array([[1, 7], [0, 4], [3, 3]])))
>>> print(output1, output2, output3)
[3, 2, 0, 4, 1],
[[9.23129916e-01, 4.93363708e-01],
[9.92489874e-01, 6.58063710e-01],
[7.35534430e-01, 7.35534430e-01]],
[7.35534430e-01, 8.26258004e-01, 9.92489874e-01, 6.58063710e-01, 9.23129916e-01]

"""

@prim_attr_register
def __init__(self, num_true=1, num_sampled=5, unique=True, range_max=5, seed=0):
"""Initialize LogUniformCandidateSampler"""
self.init_prim_io_names(inputs=['true_classes'],
outputs=['sampled_candidates', 'true_expected_count', 'sampled_expected_count'])
Validator.check_value_type("num_true", num_true, [int], self.name)
Validator.check_value_type("num_sampled", num_sampled, [int], self.name)
Validator.check_value_type("unique", unique, [bool], self.name)
Validator.check_value_type("range_max", range_max, [int], self.name)
Validator.check_value_type("seed", seed, [int], self.name)
self.num_true = Validator.check_number("num_true", num_true, 1, Rel.GE, self.name)
self.num_sampled = Validator.check_number("num_sampled", num_sampled, 1, Rel.GE, self.name)
if unique:
Validator.check_number("range_max", range_max, num_sampled, Rel.GE, self.name)
self.range_max = range_max
self.unique = unique
self.seed = seed

def infer_shape(self, true_classes_shape):
Validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name)
Validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name)
return (self.num_sampled,), true_classes_shape, (self.num_sampled,)

def infer_dtype(self, true_classes_type):
Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
valid_types = (mstype.int64,)
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, valid_types, self.name)
expected_type = mstype.float32
return true_classes_type, expected_type, expected_type

+ 47
- 0
tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py View File

@@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
def __init__(self, num_true=1):
super(Net, self).__init__()
self.sampler = P.ComputeAccidentalHits(num_true)

def construct(self, x, y):
return self.sampler(x, y)


def test_net():
x = np.array([[1, 2], [0, 4], [3, 3]])
y = np.array([0, 1, 2, 3, 4])
net = Net(2)
output1, output2, output3 = net(Tensor(x), Tensor(y))
print(output1, output2, output3)

output1_expect = np.array([0, 0, 1, 1, 2, 2])
output2_expect = np.array([1, 2, 0, 4, 3, 3])
output3_expect = np.array([-3.4028235+38, -3.4028235+38, -3.4028235+38,
-3.4028235+38, -3.4028235+38, -3.4028235+38]).astype(np.float32)
assert np.array_equal(output1.asnumpy(), output1_expect)
assert np.array_equal(output2.asnumpy(), output2_expect)
assert np.array_equal(output3.asnumpy(), output3_expect)

+ 45
- 0
tests/st/ops/ascend/test_aicpu_ops/test_log_uniform_candidate_sampler.py View File

@@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
def __init__(self, num_true=1, num_sampled=5, unique=True, range_max=5, seed=0):
super(Net, self).__init__()
self.sampler = P.LogUniformCandidateSampler(num_true, num_sampled, unique, range_max, seed)

def construct(self, x):
return self.sampler(x)


def test_net_true():
x = np.array([[1, 7], [0, 4], [3, 3]])
net = Net(2, 5, True, 5)
output = net(Tensor(x))
print(output)


def test_net_false():
x = np.array([[1, 7], [0, 4], [3, 3]])
net = Net(2, 5, False, 10)
output = net(Tensor(x))
print(output)

Loading…
Cancel
Save