|
|
|
@@ -602,9 +602,10 @@ class LogUniformCandidateSampler(PrimitiveWithInfer): |
|
|
|
Args: |
|
|
|
num_true (int): The number of target classes per training example. Default: 1. |
|
|
|
num_sampled (int): The number of classes to randomly sample. Default: 5. |
|
|
|
unique (bool): Determines whether sample with rejection. If unique is True, |
|
|
|
all sampled classes in a batch are unique. Default: True. |
|
|
|
range_max (int): The number of possible classes. Default: 5. |
|
|
|
unique (bool): Determines whether sample with rejection. If `unique` is True, |
|
|
|
all sampled classes in a batch are unique. Default: True. |
|
|
|
range_max (int): The number of possible classes. When `unique` is True, |
|
|
|
`range_max` must be greater than or equal to `num_sampled`. Default: 5. |
|
|
|
seed (int): Random seed, must be non-negative. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
@@ -644,6 +645,7 @@ class LogUniformCandidateSampler(PrimitiveWithInfer): |
|
|
|
Validator.check_value_type("seed", seed, [int], self.name) |
|
|
|
self.num_true = Validator.check_number("num_true", num_true, 1, Rel.GE, self.name) |
|
|
|
self.num_sampled = Validator.check_number("num_sampled", num_sampled, 1, Rel.GE, self.name) |
|
|
|
Validator.check_number("range_max", range_max, 1, Rel.GE, self.name) |
|
|
|
if unique: |
|
|
|
Validator.check("range_max", range_max, "num_sampled", num_sampled, Rel.GE, self.name) |
|
|
|
self.range_max = range_max |
|
|
|
|