|
|
|
@@ -78,6 +78,21 @@ def test_slice_grad2(): |
|
|
|
[[0., 0.], [8., 9.], [10., 11.]]] |
|
|
|
assert (output.asnumpy() == expect).all() |
|
|
|
|
|
|
|
def test_slice_grad3(): |
|
|
|
x = Tensor(np.array([[[1.0, 3.5, 5.8], [2.5, 4, 1]], [[3.5, 15.3, 3.1], [2.2, 4.0, 1.1]], |
|
|
|
[[43.4, 1.1, 12.1], [2.4, 6.5, 6.3]]]), mstype.float64) |
|
|
|
dy = Tensor(np.array([[[3.1, 1.1, 2.2]], [[4.4, 1.2, 4.2]]]), mstype.float64) |
|
|
|
slicegrad = SliceGrad() |
|
|
|
output = slicegrad(dy, x) |
|
|
|
expect = [[[0., 0., 0.], |
|
|
|
[3.1, 1.1, 2.2]], |
|
|
|
[[0., 0., 0.], |
|
|
|
[4.4, 1.2, 4.2]], |
|
|
|
[[0., 0., 0.], |
|
|
|
[0., 0., 0.]]] |
|
|
|
print("output:\n", output) |
|
|
|
assert (output.asnumpy() == expect).all() |
|
|
|
|
|
|
|
class StridedSliceGrad(nn.Cell): |
|
|
|
def __init__(self, x, begin, end, stride): |
|
|
|
super(StridedSliceGrad, self).__init__() |
|
|
|
|