|
|
|
@@ -514,7 +514,8 @@ class UniformCandidateSampler(PrimitiveWithInfer): |
|
|
|
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. |
|
|
|
unique (bool): Whether all sampled classes in a batch are unique. |
|
|
|
range_max (int): The number of possible classes, must be non-negative. |
|
|
|
seed (int): Random seed, must be non-negative. Default: 0. |
|
|
|
seed (int): Used for random number generation, must be non-negative. If seed has a value of 0, |
|
|
|
seed will be replaced with a randomly generated value. Default: 0. |
|
|
|
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
@@ -553,6 +554,7 @@ class UniformCandidateSampler(PrimitiveWithInfer): |
|
|
|
self.num_sampled = num_sampled |
|
|
|
|
|
|
|
def infer_dtype(self, true_classes_type): |
|
|
|
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name) |
|
|
|
return (true_classes_type, mstype.float32, mstype.float32) |
|
|
|
|
|
|
|
def infer_shape(self, true_classes_shape): |
|
|
|
|