| @@ -514,7 +514,8 @@ class UniformCandidateSampler(PrimitiveWithInfer): | |||||
| of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. | of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. | ||||
| unique (bool): Whether all sampled classes in a batch are unique. | unique (bool): Whether all sampled classes in a batch are unique. | ||||
| range_max (int): The number of possible classes, must be non-negative. | range_max (int): The number of possible classes, must be non-negative. | ||||
| seed (int): Random seed, must be non-negative. Default: 0. | |||||
| seed (int): Used for random number generation, must be non-negative. If seed has a value of 0, | |||||
| seed will be replaced with a randomly generated value. Default: 0. | |||||
| remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. | remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. | ||||
| Inputs: | Inputs: | ||||
| @@ -553,6 +554,7 @@ class UniformCandidateSampler(PrimitiveWithInfer): | |||||
| self.num_sampled = num_sampled | self.num_sampled = num_sampled | ||||
| def infer_dtype(self, true_classes_type): | def infer_dtype(self, true_classes_type): | ||||
| Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name) | |||||
| return (true_classes_type, mstype.float32, mstype.float32) | return (true_classes_type, mstype.float32, mstype.float32) | ||||
| def infer_shape(self, true_classes_shape): | def infer_shape(self, true_classes_shape): | ||||