Browse Source

fix bug of brpop of FloorMod

tags/v0.3.0-alpha
buxue 5 years ago
parent
commit
0c6cf98db0
2 changed files with 6 additions and 8 deletions
  1. +4
    -5
      mindspore/ops/_grad/grad_math_ops.py
  2. +2
    -3
      tests/ut/python/ops/test_ops.py

+ 4
- 5
mindspore/ops/_grad/grad_math_ops.py View File

@@ -255,13 +255,10 @@ def get_bprop_floordiv(self):
@bprop_getters.register(P.FloorMod) @bprop_getters.register(P.FloorMod)
def get_bprop_floormod(self): def get_bprop_floormod(self):
"""Grad definition for `FloorMod` operation.""" """Grad definition for `FloorMod` operation."""
div_op = P.FloorMod()
neg = P.Neg()
mul_op = P.Mul()


def bprop(x, y, out, dout): def bprop(x, y, out, dout):
bc_x = div_op(dout, y)
bc_y = neg(mul_op(bc_x, out))
bc_x = dout
bc_y = -dout * (x // y)
return binop_grad_common(x, y, bc_x, bc_y) return binop_grad_common(x, y, bc_x, bc_y)
return bprop return bprop


@@ -412,6 +409,7 @@ def get_bprop_reducesum(self):
def get_bprop_cumsum(self): def get_bprop_cumsum(self):
"""Grad definition for `CumSum` operation.""" """Grad definition for `CumSum` operation."""
cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse) cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

def bprop(x, axis, out, dout): def bprop(x, axis, out, dout):
return cumsum(dout, axis), zeros_like(axis) return cumsum(dout, axis), zeros_like(axis)
return bprop return bprop
@@ -787,6 +785,7 @@ def get_bprop_atan2(self):
"""Generate bprop for Atan2""" """Generate bprop for Atan2"""


square = P.Square() square = P.Square()

def bprop(x, y, out, dout): def bprop(x, y, out, dout):
tmp = dout / (square(x) + square(y)) tmp = dout / (square(x) + square(y))
dx = tmp * y dx = tmp * y


+ 2
- 3
tests/ut/python/ops/test_ops.py View File

@@ -351,9 +351,8 @@ test_case_math_ops = [
'skip': ['backward']}), 'skip': ['backward']}),
('FloorMod', { ('FloorMod', {
'block': P.FloorMod(), 'block': P.FloorMod(),
'desc_inputs': [Tensor(np.random.rand(4).astype(np.float16)),
Tensor(np.random.rand(4).astype(np.float16))],
'skip': ['backward']}),
'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]],
'desc_bprop': [[2, 3, 4, 5]]}),
('identity', { ('identity', {
'block': ops.functional.identity, 'block': ops.functional.identity,
'desc_inputs': [[2, 2]], 'desc_inputs': [[2, 2]],


Loading…
Cancel
Save