Browse Source

move P.UniformCandidateSampler to ops.random

tags/v1.1.0
TFbunny 5 years ago
parent
commit
1b22e6d8ae
5 changed files with 64 additions and 64 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.cc
  2. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h
  3. +2
    -2
      mindspore/ops/operations/__init__.py
  4. +0
    -58
      mindspore/ops/operations/nn_ops.py
  5. +58
    -0
      mindspore/ops/operations/random_ops.py

mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.cc → mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.cc View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h"
#include "backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h"

namespace mindspore {
namespace kernel {

mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h → mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_

#include <cmath>
#include <set>
@@ -160,4 +160,4 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_

+ 2
- 2
mindspore/ops/operations/__init__.py View File

@@ -57,7 +57,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)

from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial)
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
@@ -79,7 +79,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
FusedSparseFtrl, FusedSparseProximalAdagrad,
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformCandidateSampler)
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,


+ 0
- 58
mindspore/ops/operations/nn_ops.py View File

@@ -6201,61 +6201,3 @@ class LRN(PrimitiveWithInfer):
def infer_shape(self, x_shape):
validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name)
return x_shape


class UniformCandidateSampler(PrimitiveWithInfer):
r"""
Uniform candidate sampler.

This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution.
If unique=True, candidates are drawn without replacement, else unique=False with replacement.

Args:
num_true (int): The number of target classes in each training example.
num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
unique (bool): Whether all sampled classes in a batch are unique.
range_max (int): The number of possible classes.
seed (int): Random seed, must be non-negative. Default: 0.
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.

Inputs:
true_classes (int): A tensor. The target classes with a tensor shape of (batch_size, num_true).

Outputs:
A tuple of 3 tensors.

sampled_candidates: (int): The sampled_candidates is independent of the true classes. Shape: (num_sampled, ).
true_expected_count: (float): The expected counts under the sampling distribution of each of true_classes.
Shape: (batch_size, num_true).
sampled_expected_count: (float): The expected counts under the sampling distribution of each of
sampled_candidates. Shape: (num_sampled, ).

Examples:
>>> sampler = P.UniformCandidateSampler(1, 3, False, 4)
>>> output1, output2, output3 = sampler(Tensor(np.array([[1],[3],[4],[6],[3]], dtype=np.int32)))
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
"""

@prim_attr_register
def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
"""Initialize UniformCandidateSampler"""
validator.check_value_type("num_true", num_true, [int], self.name)
validator.check_value_type("num_sampled", num_sampled, [int], self.name)
validator.check_value_type("unique", unique, [bool], self.name)
validator.check_value_type("range_max", range_max, [int], self.name)
validator.check_value_type("seed", seed, [int], self.name)
validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
self.num_true = num_true
if unique:
validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
validator.check("value of seed", seed, '', 0, Rel.GE, self.name)
self.num_sampled = num_sampled

def infer_dtype(self, true_classes_type):
return (true_classes_type, mstype.float32, mstype.float32)

def infer_shape(self, true_classes_shape):
validator.check("true_class[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
return ([self.num_sampled], true_classes_shape, [self.num_sampled])

+ 58
- 0
mindspore/ops/operations/random_ops.py View File

@@ -500,3 +500,61 @@ class Multinomial(PrimitiveWithInfer):
"dtype": mstype.int32,
"value": None}
return out

class UniformCandidateSampler(PrimitiveWithInfer):
r"""
Uniform candidate sampler.

This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution.
If unique=True, candidates are drawn without replacement, else unique=False with replacement.

Args:
num_true (int): The number of target classes in each training example.
num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
unique (bool): Whether all sampled classes in a batch are unique.
range_max (int): The number of possible classes, must be non-negative.
seed (int): Random seed, must be non-negative. Default: 0.
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.

Inputs:
- **true_classes** (Tensor) - A Tensor. The target classes with a Tensor shape of (batch_size, num_true).

Outputs:
- **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
Shape: (num_sampled, ).
- **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
of true_classes. Shape: (batch_size, num_true).
- **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
each of sampled_candidates. Shape: (num_sampled, ).

Examples:
>>> sampler = P.UniformCandidateSampler(1, 3, False, 4)
>>> output1, output2, output3 = sampler(Tensor(np.array([[1],[3],[4],[6],[3]], dtype=np.int32)))
>>> print(output1, output2, output3)
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
"""

@prim_attr_register
def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
"""Initialize UniformCandidateSampler"""
Validator.check_value_type("num_true", num_true, [int], self.name)
Validator.check_value_type("num_sampled", num_sampled, [int], self.name)
Validator.check_value_type("unique", unique, [bool], self.name)
Validator.check_value_type("range_max", range_max, [int], self.name)
Validator.check_value_type("seed", seed, [int], self.name)
Validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
Validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
Validator.check("value of range_max", range_max, '', 0, Rel.GT, self.name)
self.num_true = num_true
if unique:
Validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
Validator.check("value of seed", seed, '', 0, Rel.GE, self.name)
self.num_sampled = num_sampled

def infer_dtype(self, true_classes_type):
return (true_classes_type, mstype.float32, mstype.float32)

def infer_shape(self, true_classes_shape):
Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
return ([self.num_sampled], true_classes_shape, [self.num_sampled])

Loading…
Cancel
Save