diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index cdce91ece5..16f3dcb174 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ class _Loss(Cell):
     """
     Base class for other losses.
     """
-    def __init__(self, reduction='mean', weights=1.0):
+    def __init__(self, reduction='mean'):
         super(_Loss, self).__init__()
         if reduction is None:
             reduction = 'none'
@@ -47,10 +47,7 @@ class _Loss(Cell):
         self.reduce_mean = _selected_ops.ReduceMean()
         self.reduce_sum = P.ReduceSum()
         self.mul = P.Mul()
-        if isinstance(weights, int):
-            self.weights = float(weights)
-        else:
-            self.weights = weights
+        self.cast = P.Cast()
 
     def get_axis(self, x):
         shape = F.shape(x)
@@ -58,13 +55,31 @@ class _Loss(Cell):
         perm = F.make_range(0, length)
         return perm
 
-    def get_loss(self, x):
-        if self.weights != 1.0:
-            x = self.mul(self.weights, x)
+    def get_loss(self, x, weights=1.0):
+        """
+        Computes the weighted loss.
+
+        `x` and `weights` are cast to float32 and multiplied; the product is then reduced with ReduceMean
+        or ReduceSum (depending on `self.reduce` and `self.average`) and cast back to the dtype of `x`.
+
+        Args:
+            x (Tensor): Loss values to be weighted and, if reduction is enabled, reduced over all axes.
+            weights (Union[float, int, Tensor]): Optional weights applied to `x`. A `Tensor` must have rank 0
+                or the same rank as `x`, and must be broadcastable to `x` (i.e., every dimension must be
+                either `1` or the same as the corresponding dimension of `x`). Default: 1.0.
+
+        Returns:
+            Tensor, the weighted (and possibly reduced) loss, with the same dtype as `x`.
+        """
+        input_dtype = x.dtype
+        x = self.cast(x, mstype.float32)
+        weights = self.cast(weights, mstype.float32)
+        x = self.mul(weights, x)
         if self.reduce and self.average:
             x = self.reduce_mean(x, self.get_axis(x))
         if self.reduce and not self.average:
             x = self.reduce_sum(x, self.get_axis(x))
+        x = self.cast(x, input_dtype)
         return x
 
     def construct(self, base, target):
diff --git a/tests/st/ops/gpu/test_loss.py b/tests/st/ops/gpu/test_loss.py
index 693f8f01bd..a5d95a9859 100644
--- a/tests/st/ops/gpu/test_loss.py
+++ b/tests/st/ops/gpu/test_loss.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,19 +14,24 @@
 # ============================================================================
 """ test loss """
 import numpy as np
-
+import mindspore
 from mindspore import Tensor
 from mindspore.ops import operations as P
 from mindspore.nn.loss.loss import _Loss
+from mindspore.nn.loss.loss import L1Loss
+import mindspore.context as context
+
+context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
 
 class WeightedLoss(_Loss):
     def __init__(self, reduction='mean', weights=1.0):
-        super(WeightedLoss, self).__init__(reduction, weights)
+        super(WeightedLoss, self).__init__(reduction)
         self.abs = P.Abs()
+        self.weights = weights
 
     def construct(self, base, target):
         x = self.abs(base - target)
-        return self.get_loss(x)
+        return self.get_loss(x, self.weights)
 
 def test_WeightedLoss():
     loss = WeightedLoss()
@@ -35,17 +40,41 @@ def test_WeightedLoss():
     output_data = loss(input_data, target_data)
 
     error_range = np.ones(shape=output_data.shape) * 10e-6
-    loss.weights = 1.0
+    loss = WeightedLoss(weights=2.0)
     test_output = loss(input_data, target_data)
-    diff = test_output - output_data * loss.weights
+    diff = test_output - output_data * 2.0
     assert np.all(abs(diff.asnumpy()) < error_range)
 
-    loss.weights = 2.0
+    loss = WeightedLoss(weights=3)
     test_output = loss(input_data, target_data)
-    diff = test_output - output_data * loss.weights
+    diff = test_output - output_data * 3
     assert np.all(abs(diff.asnumpy()) < error_range)
 
-    loss.weights = 3
-    test_output = loss(input_data, target_data)
-    diff = test_output - output_data * loss.weights
+    loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]])))
+    y_true = Tensor(np.array([[0., 1.], [0., 0.]]), mindspore.float32)
+    y_pred = Tensor(np.array([[1., 1.], [1., 0.]]), mindspore.float32)
+    test_data = 0.35
+    output = loss(y_true, y_pred)
+    diff = test_data - output.asnumpy()
+    assert np.all(abs(diff) < error_range)
+
+class CustomLoss(_Loss):
+    def __init__(self, reduction='mean'):
+        super(CustomLoss, self).__init__(reduction)
+        self.abs = P.Abs()
+
+    def construct(self, base, target):
+        x = self.abs(base - target)
+        return self.get_loss(x, weights=2.0)
+
+def test_CustomLoss():
+    loss = L1Loss()
+    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
+    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
+    output_data = loss(input_data, target_data)
+
+    error_range = np.ones(shape=output_data.shape) * 10e-6
+    customloss = CustomLoss()
+    test_output = customloss(input_data, target_data)
+    diff = test_output - output_data * 2.0
     assert np.all(abs(diff.asnumpy()) < error_range)