You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.ops.LogUniformCandidateSampler.rst 1.7 kB

1234567891011121314151617181920212223242526272829303132333435
  1. mindspore.ops.LogUniformCandidateSampler
  2. =========================================
  3. .. py:class:: mindspore.ops.LogUniformCandidateSampler(num_true=1, num_sampled=5, unique=True, range_max=5, seed=0)
  4. 使用log-uniform(Zipfian)分布对一组类进行采样。
  5. 该操作从整数范围[0, range_max)中随机采样一个采样类(sampled_candidates)的Tensor。
  6. **参数:**
  7. - **num_true** (int) - 每个训练样本的目标类数。默认值:1。
  8. - **num_sampled** (int) - 随机采样的类数。默认值:5。
  9. - **unique** (bool) - 确认批处理中的所有采样类是否都是唯一的。如果 `unique` 为True,则批处理中的所有采样类都唯一。默认值:True。
  10. - **range_max** (int) - 可能的类数。当 `unique` 为True时, `range_max` 必须大于或等于 `num_sampled` 。默认值:5。
  11. - **seed** (int) - 随机种子,必须是非负。默认值:0。
  12. **输入:**
  13. - **true_classes** (Tensor) - 目标类,其数据类型为int64,shape为[batch_size, num_true]。
  14. **输出:**
  15. 3个Tensor组成的元组。
  16. - **sampled_candidates** (Tensor) - shape为(num_sampled,)且数据类型与 `true_classes` 相同的Tensor。
  17. - **true_expected_count** (Tensor) - shape与 `true_classes` 相同且数据类型为float32的Tensor。
  18. - **sampled_expected_count** (Tensor) - shape与 `sampled_candidates` 相同且数据类型为float32的Tensor。
  19. **异常:**
  20. - **TypeError** - `num_true` 和 `num_sampled` 都不是int。
  21. - **TypeError** - `unique` 不是bool。
  22. - **TypeError** - `range_max` 和 `seed` 都不是int。
  23. - **TypeError** - `true_classes` 不是Tensor。