Browse Source

!10164 add LogUniformCandidateSampler input check

From: @yanzhenxiang2020
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @wuxuejian
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
956e0cf1dc
2 changed files with 6 additions and 4 deletions
  1. +1
    -1
      mindspore/ops/operations/nn_ops.py
  2. +5
    -3
      mindspore/ops/operations/random_ops.py

+ 1
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -3423,7 +3423,7 @@ class ComputeAccidentalHits(PrimitiveWithCheck):
- **true_classes** (Tensor) - The target classes. With data type of int32 or int64
and shape [batch_size, num_true].
- **sampled_candidates** (Tensor) - The sampled_candidates output of CandidateSampler,
with shape [num_sampled] and the same type as true_classes.
with data type of int32 or int64 and shape [num_sampled].

Outputs:
Tuple of 3 Tensors.


+ 5
- 3
mindspore/ops/operations/random_ops.py View File

@@ -602,9 +602,10 @@ class LogUniformCandidateSampler(PrimitiveWithInfer):
Args:
num_true (int): The number of target classes per training example. Default: 1.
num_sampled (int): The number of classes to randomly sample. Default: 5.
unique (bool): Determines whether sample with rejection. If unique is True,
all sampled classes in a batch are unique. Default: True.
range_max (int): The number of possible classes. Default: 5.
unique (bool): Determines whether sample with rejection. If `unique` is True,
all sampled classes in a batch are unique. Default: True.
range_max (int): The number of possible classes. When `unique` is True,
`range_max` must be greater than or equal to `num_sampled`. Default: 5.
seed (int): Random seed, must be non-negative.

Inputs:
@@ -644,6 +645,7 @@ class LogUniformCandidateSampler(PrimitiveWithInfer):
Validator.check_value_type("seed", seed, [int], self.name)
self.num_true = Validator.check_number("num_true", num_true, 1, Rel.GE, self.name)
self.num_sampled = Validator.check_number("num_sampled", num_sampled, 1, Rel.GE, self.name)
Validator.check_number("range_max", range_max, 1, Rel.GE, self.name)
if unique:
Validator.check("range_max", range_max, "num_sampled", num_sampled, Rel.GE, self.name)
self.range_max = range_max


Loading…
Cancel
Save