Browse Source

blocking int64 from reducing percision to int32

tags/v1.1.0
TFbunny 5 years ago
parent
commit
b82d42cc56
1 changed files with 3 additions and 1 deletions
  1. +3
    -1
      mindspore/ops/operations/random_ops.py

+ 3
- 1
mindspore/ops/operations/random_ops.py View File

@@ -514,7 +514,8 @@ class UniformCandidateSampler(PrimitiveWithInfer):
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
unique (bool): Whether all sampled classes in a batch are unique. unique (bool): Whether all sampled classes in a batch are unique.
range_max (int): The number of possible classes, must be non-negative. range_max (int): The number of possible classes, must be non-negative.
seed (int): Random seed, must be non-negative. Default: 0.
seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
seed will be replaced with a randomly generated value. Default: 0.
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.


Inputs: Inputs:
@@ -553,6 +554,7 @@ class UniformCandidateSampler(PrimitiveWithInfer):
self.num_sampled = num_sampled self.num_sampled = num_sampled


def infer_dtype(self, true_classes_type): def infer_dtype(self, true_classes_type):
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name)
return (true_classes_type, mstype.float32, mstype.float32) return (true_classes_type, mstype.float32, mstype.float32)


def infer_shape(self, true_classes_shape): def infer_shape(self, true_classes_shape):


Loading…
Cancel
Save