Browse Source

fix second grad output of BCEWithLogicLoss.

tags/v1.2.0-rc1
liuxiao93 4 years ago
parent
commit
522aad27b6
1 changed files with 7 additions and 1 deletions
  1. +7
    -1
      mindspore/ops/_grad/grad_nn_ops.py

+ 7
- 1
mindspore/ops/_grad/grad_nn_ops.py View File

@@ -1090,19 +1090,25 @@ def get_bprop_ce_with_logits_loss(self):
add = P.Add() add = P.Add()
sub = P.Sub() sub = P.Sub()
size = P.Size() size = P.Size()
neg = P.Neg()
log = P.Log()


def bprop(predict, target, weight, pos_weight, out, dout): def bprop(predict, target, weight, pos_weight, out, dout):
sigmoid_input = sigmoid(predict) sigmoid_input = sigmoid(predict)
if pos_weight is not None: if pos_weight is not None:
t = mul(target, pos_weight) t = mul(target, pos_weight)
dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout) dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout)
grad_target = mul(sub(log(sub(1, sigmoid_input)), mul(pos_weight, log(sigmoid_input))), dout)
else: else:
dx = mul((sigmoid_input - target), dout) dx = mul((sigmoid_input - target), dout)
grad_target = mul(predict, neg(dout))
if weight is not None: if weight is not None:
dx = mul(dx, weight) dx = mul(dx, weight)
grad_target = mul(grad_target, weight)
if reduction == 'mean': if reduction == 'mean':
dx = dx / size(dx) dx = dx / size(dx)
return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight)
grad_target = grad_target / size(target)
return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)


return bprop return bprop




Loading…
Cancel
Save