|
|
|
@@ -219,7 +219,7 @@ class OneInputBprop(nn.Cell): |
|
|
|
return self.op(x) |
|
|
|
|
|
|
|
def bprop(self, x, out, dout): |
|
|
|
return 5 * x, |
|
|
|
return (5 * x,) |
|
|
|
|
|
|
|
|
|
|
|
def test_grad_one_input_bprop(): |
|
|
|
@@ -349,7 +349,7 @@ class MulAddWithWrongOutputNum(nn.Cell): |
|
|
|
return 2 * x + y |
|
|
|
|
|
|
|
def bprop(self, x, y, out, dout): |
|
|
|
return 2 * dout, |
|
|
|
return (2 * dout,) |
|
|
|
|
|
|
|
|
|
|
|
def test_grad_mul_add_with_wrong_output_num(): |
|
|
|
@@ -380,7 +380,7 @@ def test_grad_mul_add_with_wrong_output_type(): |
|
|
|
class MulAddWithWrongOutputShape(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super(MulAddWithWrongOutputShape, self).__init__() |
|
|
|
self.ones = Tensor(np.ones([2, ])) |
|
|
|
self.ones = Tensor(np.ones([2,])) |
|
|
|
|
|
|
|
def construct(self, x, y): |
|
|
|
return 2 * x + y |
|
|
|
|