Browse Source

Conversion weights type when compute weights loss

tags/v1.2.0-rc1
xuguoyang 5 years ago
parent
commit
7dcc3d89c3
2 changed files with 55 additions and 20 deletions
  1. +15
    -9
      mindspore/nn/loss/loss.py
  2. +40
    -11
      tests/st/ops/gpu/test_loss.py

+ 15
- 9
mindspore/nn/loss/loss.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ class _Loss(Cell):
""" """
Base class for other losses. Base class for other losses.
""" """
def __init__(self, reduction='mean', weights=1.0):
def __init__(self, reduction='mean'):
super(_Loss, self).__init__() super(_Loss, self).__init__()
if reduction is None: if reduction is None:
reduction = 'none' reduction = 'none'
@@ -47,10 +47,7 @@ class _Loss(Cell):
self.reduce_mean = _selected_ops.ReduceMean() self.reduce_mean = _selected_ops.ReduceMean()
self.reduce_sum = P.ReduceSum() self.reduce_sum = P.ReduceSum()
self.mul = P.Mul() self.mul = P.Mul()
if isinstance(weights, int):
self.weights = float(weights)
else:
self.weights = weights
self.cast = P.Cast()


def get_axis(self, x): def get_axis(self, x):
shape = F.shape(x) shape = F.shape(x)
@@ -58,13 +55,22 @@ class _Loss(Cell):
perm = F.make_range(0, length) perm = F.make_range(0, length)
return perm return perm


def get_loss(self, x):
if self.weights != 1.0:
x = self.mul(self.weights, x)
def get_loss(self, x, weights=1.0):
"""
Computes the weighted loss
Args:
weights: Optional `Tensor` whose rank is either 0, or the same rank as inputs, and must be broadcastable to
inputs (i.e., all dimensions must be either `1`, or the same as the corresponding inputs dimension).
"""
input_dtype = x.dtype
x = self.cast(x, mstype.float32)
weights = self.cast(weights, mstype.float32)
x = self.mul(weights, x)
if self.reduce and self.average: if self.reduce and self.average:
x = self.reduce_mean(x, self.get_axis(x)) x = self.reduce_mean(x, self.get_axis(x))
if self.reduce and not self.average: if self.reduce and not self.average:
x = self.reduce_sum(x, self.get_axis(x)) x = self.reduce_sum(x, self.get_axis(x))
x = self.cast(x, input_dtype)
return x return x


def construct(self, base, target): def construct(self, base, target):


+ 40
- 11
tests/st/ops/gpu/test_loss.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,19 +14,24 @@
# ============================================================================ # ============================================================================
""" test loss """ """ test loss """
import numpy as np import numpy as np
import mindspore
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import L1Loss
import mindspore.context as context

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')


class WeightedLoss(_Loss): class WeightedLoss(_Loss):
def __init__(self, reduction='mean', weights=1.0): def __init__(self, reduction='mean', weights=1.0):
super(WeightedLoss, self).__init__(reduction, weights)
super(WeightedLoss, self).__init__(reduction)
self.abs = P.Abs() self.abs = P.Abs()
self.weights = weights


def construct(self, base, target): def construct(self, base, target):
x = self.abs(base - target) x = self.abs(base - target)
return self.get_loss(x)
return self.get_loss(x, self.weights)


def test_WeightedLoss(): def test_WeightedLoss():
loss = WeightedLoss() loss = WeightedLoss()
@@ -35,17 +40,41 @@ def test_WeightedLoss():
output_data = loss(input_data, target_data) output_data = loss(input_data, target_data)


error_range = np.ones(shape=output_data.shape) * 10e-6 error_range = np.ones(shape=output_data.shape) * 10e-6
loss.weights = 1.0
loss = WeightedLoss(weights=2.0)
test_output = loss(input_data, target_data) test_output = loss(input_data, target_data)
diff = test_output - output_data * loss.weights
diff = test_output - output_data * 2.0
assert np.all(abs(diff.asnumpy()) < error_range) assert np.all(abs(diff.asnumpy()) < error_range)


loss.weights = 2.0
loss = WeightedLoss(weights=3)
test_output = loss(input_data, target_data) test_output = loss(input_data, target_data)
diff = test_output - output_data * loss.weights
diff = test_output - output_data * 3
assert np.all(abs(diff.asnumpy()) < error_range) assert np.all(abs(diff.asnumpy()) < error_range)


loss.weights = 3
test_output = loss(input_data, target_data)
diff = test_output - output_data * loss.weights
loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]])))
y_true = Tensor(np.array([[0., 1.], [0., 0.]]), mindspore.float32)
y_pred = Tensor(np.array([[1., 1.], [1., 0.]]), mindspore.float32)
test_data = 0.35
output = loss(y_true, y_pred)
diff = test_data - output.asnumpy()
assert np.all(abs(diff) < error_range)

class CustomLoss(_Loss):
def __init__(self, reduction='mean'):
super(CustomLoss, self).__init__(reduction)
self.abs = P.Abs()

def construct(self, base, target):
x = self.abs(base - target)
return self.get_loss(x, weights=2.0)

def test_CustomLoss():
loss = L1Loss()
input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
output_data = loss(input_data, target_data)

error_range = np.ones(shape=output_data.shape) * 10e-6
customloss = CustomLoss()
test_output = customloss(input_data, target_data)
diff = test_output - output_data * 2.0
assert np.all(abs(diff.asnumpy()) < error_range) assert np.all(abs(diff.asnumpy()) < error_range)

Loading…
Cancel
Save