|
|
|
@@ -284,7 +284,7 @@ class SampledSoftmaxLoss(_Loss): |
|
|
|
where a sampled class equals one of the target classes. Default is True. |
|
|
|
seed (int): Random seed for candidate sampling. Default: 0 |
|
|
|
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none". |
|
|
|
If "none", do not perform reduction. Default: "None". |
|
|
|
If "none", do not perform reduction. Default: "none". |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **weights** (Tensor) - Tensor of shape (C, dim). |
|
|
|
@@ -311,7 +311,22 @@ class SampledSoftmaxLoss(_Loss): |
|
|
|
def __init__(self, num_sampled, num_classes, num_true=1, |
|
|
|
sampled_values=None, remove_accidental_hits=True, seed=0, |
|
|
|
reduction='none'): |
|
|
|
super(SampledSoftmaxLoss, self).__init__() |
|
|
|
super(SampledSoftmaxLoss, self).__init__(reduction) |
|
|
|
|
|
|
|
if num_true < 1: |
|
|
|
raise ValueError(f"num_true {num_true} is less than 1.") |
|
|
|
if seed < 0: |
|
|
|
raise ValueError(f"seed {seed} is less than 0.") |
|
|
|
if num_sampled > num_classes: |
|
|
|
raise ValueError(f"num_sampled {num_sampled} is great than num_classes {num_classes}.") |
|
|
|
if num_true > num_classes: |
|
|
|
raise ValueError(f"num_true {num_true} is great than num_classes {num_classes}.") |
|
|
|
if sampled_values is not None: |
|
|
|
if not isinstance(sampled_values, (list, tuple)): |
|
|
|
raise TypeError(f"sampled_values {sampled_values} is not a list.") |
|
|
|
if len(sampled_values) != 3: |
|
|
|
raise ValueError(f"sampled_values size {len(sampled_values)} is not 3.") |
|
|
|
|
|
|
|
self.num_sampled = num_sampled |
|
|
|
self.num_classes = num_classes |
|
|
|
self.num_true = num_true |
|
|
|
|