|
|
|
@@ -28,6 +28,7 @@ from mindspore._checkparam import Validator as validator |
|
|
|
from mindspore._checkparam import Rel |
|
|
|
from ... import context |
|
|
|
|
|
|
|
|
|
|
|
class LossBase(Cell): |
|
|
|
""" |
|
|
|
Base class for other losses. |
|
|
|
@@ -124,6 +125,7 @@ def _check_is_tensor(param_name, input_data, cls_name): |
|
|
|
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', " |
|
|
|
f"but got '{F.typeof(input_data)}'") |
|
|
|
|
|
|
|
|
|
|
|
class L1Loss(LossBase): |
|
|
|
r""" |
|
|
|
L1Loss creates a criterion to measure the mean absolute error (MAE) between :math:`x` and :math:`y` element-wise, |
|
|
|
@@ -580,6 +582,7 @@ class SoftmaxCrossEntropyWithLogits(LossBase): |
|
|
|
x = self.softmax_cross_entropy(logits, labels)[0] |
|
|
|
return self.get_loss(x) |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _check_label_dtype(labels_dtype, cls_name): |
|
|
|
"""Internal function, used to check whether the data type of labels meets the requirements.""" |
|
|
|
|