Browse Source

fix doc and add check for UniformCandidateSampler

tags/v1.1.0
TFbunny 5 years ago
parent
commit
941645c363
1 changed files with 8 additions and 6 deletions
  1. +8
    -6
      mindspore/ops/operations/nn_ops.py

+ 8
- 6
mindspore/ops/operations/nn_ops.py View File

@@ -5956,8 +5956,8 @@ class UniformCandidateSampler(PrimitiveWithInfer):

Args:
num_true (int): The number of target classes in each training example.
num_sampled (int): The number of classes to randomly sample. The **sampled_candidates** will have a shape
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
unique (bool): Whether all sampled classes in a batch are unique.
range_max (int): The number of possible classes.
seed (int): Random seed, must be non-negative. Default: 0.
@@ -5968,16 +5968,16 @@ class UniformCandidateSampler(PrimitiveWithInfer):

Outputs:
A tuple of 3 tensors.

sampled_candidates: (int): The sampled_candidates is independent of the true classes. Shape: (num_sampled, ).
true_expected_count: (float): The expected counts under the sampling distribution of each of true_classes.
Shape: (batch_size, num_true).
Shape: (batch_size, num_true).
sampled_expected_count: (float): The expected counts under the sampling distribution of each of
sampled_candidates. Shape: (num_sampled, ).
sampled_candidates. Shape: (num_sampled, ).

Examples:
>>> sampler = P.UniformCandidateSampler(1, 3, False, 4)
>>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6],
[3]], dtype=np.int32)))
>>> output1, output2, output3 = sampler(Tensor(np.array([[1],[3],[4],[6],[3]], dtype=np.int32)))
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
"""

@@ -5991,6 +5991,7 @@ class UniformCandidateSampler(PrimitiveWithInfer):
validator.check_value_type("seed", seed, [int], self.name)
validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
self.num_true = num_true
if unique:
validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
validator.check("value of seed", seed, '', 0, Rel.GE, self.name)
@@ -6000,4 +6001,5 @@ class UniformCandidateSampler(PrimitiveWithInfer):
return (true_classes_type, mstype.float32, mstype.float32)

def infer_shape(self, true_classes_shape):
validator.check("true_class[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
return ([self.num_sampled], true_classes_shape, [self.num_sampled])

Loading…
Cancel
Save