|
|
|
@@ -5956,8 +5956,8 @@ class UniformCandidateSampler(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Args: |
|
|
|
num_true (int): The number of target classes in each training example. |
|
|
|
num_sampled (int): The number of classes to randomly sample. The **sampled_candidates** will have a shape |
|
|
|
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. |
|
|
|
num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape |
|
|
|
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. |
|
|
|
unique (bool): Whether all sampled classes in a batch are unique. |
|
|
|
range_max (int): The number of possible classes. |
|
|
|
seed (int): Random seed, must be non-negative. Default: 0. |
|
|
|
@@ -5968,16 +5968,16 @@ class UniformCandidateSampler(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Outputs: |
|
|
|
A tuple of 3 tensors. |
|
|
|
|
|
|
|
sampled_candidates: (int): The sampled_candidates is independent of the true classes. Shape: (num_sampled, ). |
|
|
|
true_expected_count: (float): The expected counts under the sampling distribution of each of true_classes. |
|
|
|
Shape: (batch_size, num_true). |
|
|
|
Shape: (batch_size, num_true). |
|
|
|
sampled_expected_count: (float): The expected counts under the sampling distribution of each of |
|
|
|
sampled_candidates. Shape: (num_sampled, ). |
|
|
|
sampled_candidates. Shape: (num_sampled, ). |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> sampler = P.UniformCandidateSampler(1, 3, False, 4) |
|
|
|
>>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6], |
|
|
|
[3]], dtype=np.int32))) |
|
|
|
>>> output1, output2, output3 = sampler(Tensor(np.array([[1],[3],[4],[6],[3]], dtype=np.int32))) |
|
|
|
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] |
|
|
|
""" |
|
|
|
|
|
|
|
@@ -5991,6 +5991,7 @@ class UniformCandidateSampler(PrimitiveWithInfer): |
|
|
|
validator.check_value_type("seed", seed, [int], self.name) |
|
|
|
validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name) |
|
|
|
validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name) |
|
|
|
self.num_true = num_true |
|
|
|
if unique: |
|
|
|
validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name) |
|
|
|
validator.check("value of seed", seed, '', 0, Rel.GE, self.name) |
|
|
|
@@ -6000,4 +6001,5 @@ class UniformCandidateSampler(PrimitiveWithInfer): |
|
|
|
return (true_classes_type, mstype.float32, mstype.float32) |
|
|
|
|
|
|
|
def infer_shape(self, true_classes_shape): |
|
|
|
validator.check("true_class[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name) |
|
|
|
return ([self.num_sampled], true_classes_shape, [self.num_sampled]) |