diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index d86dc6906f..16690d5478 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -514,7 +514,8 @@ class UniformCandidateSampler(PrimitiveWithInfer): of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. unique (bool): Whether all sampled classes in a batch are unique. range_max (int): The number of possible classes, must be non-negative. - seed (int): Random seed, must be non-negative. Default: 0. + seed (int): Used for random number generation, must be non-negative. If seed has a value of 0, + seed will be replaced with a randomly generated value. Default: 0. remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. Inputs: @@ -553,6 +554,7 @@ class UniformCandidateSampler(PrimitiveWithInfer): self.num_sampled = num_sampled def infer_dtype(self, true_classes_type): + Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name) return (true_classes_type, mstype.float32, mstype.float32) def infer_shape(self, true_classes_shape):