| @@ -1090,19 +1090,25 @@ def get_bprop_ce_with_logits_loss(self): | |||||
| add = P.Add() | add = P.Add() | ||||
| sub = P.Sub() | sub = P.Sub() | ||||
| size = P.Size() | size = P.Size() | ||||
| neg = P.Neg() | |||||
| log = P.Log() | |||||
| def bprop(predict, target, weight, pos_weight, out, dout): | def bprop(predict, target, weight, pos_weight, out, dout): | ||||
| sigmoid_input = sigmoid(predict) | sigmoid_input = sigmoid(predict) | ||||
| if pos_weight is not None: | if pos_weight is not None: | ||||
| t = mul(target, pos_weight) | t = mul(target, pos_weight) | ||||
| dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout) | dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout) | ||||
| grad_target = mul(sub(log(sub(1, sigmoid_input)), mul(pos_weight, log(sigmoid_input))), dout) | |||||
| else: | else: | ||||
| dx = mul((sigmoid_input - target), dout) | dx = mul((sigmoid_input - target), dout) | ||||
| grad_target = mul(predict, neg(dout)) | |||||
| if weight is not None: | if weight is not None: | ||||
| dx = mul(dx, weight) | dx = mul(dx, weight) | ||||
| grad_target = mul(grad_target, weight) | |||||
| if reduction == 'mean': | if reduction == 'mean': | ||||
| dx = dx / size(dx) | dx = dx / size(dx) | ||||
| return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight) | |||||
| grad_target = grad_target / size(target) | |||||
| return dx, grad_target, zeros_like(weight), zeros_like(pos_weight) | |||||
| return bprop | return bprop | ||||