|
|
|
@@ -281,6 +281,10 @@ class SoftmaxCrossEntropyWithLogits(_Loss): |
|
|
|
x = self.softmax_cross_entropy(logits, labels)[0] |
|
|
|
return self.get_loss(x) |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _check_label_dtype(labels_dtype, cls_name): |
|
|
|
validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name) |
|
|
|
|
|
|
|
|
|
|
|
class SampledSoftmaxLoss(_Loss): |
|
|
|
r""" |
|
|
|
@@ -373,8 +377,11 @@ class SampledSoftmaxLoss(_Loss): |
|
|
|
self.zeros_like = P.ZerosLike() |
|
|
|
self.mul = P.Mul() |
|
|
|
self.expand_dims = P.ExpandDims() |
|
|
|
self.dtype = P.DType() |
|
|
|
|
|
|
|
def construct(self, weights, biases, labels, inputs): |
|
|
|
_check_label_dtype(self.dtype(labels), self.cls_name) |
|
|
|
|
|
|
|
logits, labels = self._compute_sampled_logits( |
|
|
|
weights=weights, |
|
|
|
biases=biases, |
|
|
|
@@ -424,6 +431,7 @@ class SampledSoftmaxLoss(_Loss): |
|
|
|
`[batch_size, num_true + num_sampled]` |
|
|
|
out_labels: A Tensor object with the same shape as `out_logits`. |
|
|
|
""" |
|
|
|
|
|
|
|
if not labels.dtype == mstype.int32: |
|
|
|
labels = self.cast(labels, mstype.int32) |
|
|
|
labels = self.reshape(labels, (-1, num_true)) |
|
|
|
|