Browse Source

fix cpu StridedSliceGrad

tags/v0.5.0-beta
sunsuodong 5 years ago
parent
commit
63d9e47291
3 changed files with 34 additions and 7 deletions
  1. +5
    -5
      mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc
  2. +28
    -1
      tests/st/ops/cpu/test_stridedslice_grad_op.py
  3. +1
    -1
      tests/st/ops/cpu/test_stridedslice_op.py

+ 5
- 5
mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc View File

@@ -61,11 +61,11 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.emplace_back(begin_[i] + sizes[i]); end_.emplace_back(begin_[i] + sizes[i]);
} }
} }
CPUKernelUtils::ExpandDimsTo4(&output_dx_shape_);
auto input_len = input_dy_shape_.size();
if (input_len < 4) {
for (size_t i = 0; i < 4 - input_len; ++i) {
input_dy_shape_.insert(input_dy_shape_.begin(), 1);
auto output_len = output_dx_shape_.size();
if (output_len < 4) {
for (size_t i = 0; i < 4 - output_len; ++i) {
output_dx_shape_.insert(output_dx_shape_.begin(), 1);
begin_.insert(begin_.begin(), 0); begin_.insert(begin_.begin(), 0);
strides_.insert(strides_.begin(), 1); strides_.insert(strides_.begin(), 1);
end_.insert(end_.begin(), 1); end_.insert(end_.begin(), 1);


+ 28
- 1
tests/st/ops/cpu/test_stridedslice_grad_op.py View File

@@ -19,6 +19,7 @@ import pytest
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
@@ -38,7 +39,7 @@ class StridedSliceGrad(nn.Cell):




@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_slice(): def test_slice():
x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32)) x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32))
@@ -47,3 +48,29 @@ def test_slice():
output = ssg(dy, x) output = ssg(dy, x)
expect = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[5, 1, 5], [6, 1, 8]]] expect = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[5, 1, 5], [6, 1, 8]]]
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()


class StridedSliceGrad2(nn.Cell):
def __init__(self):
super(StridedSliceGrad2, self).__init__()
self.ssg = G.StridedSliceGrad()
self.shape = P.Shape()

@ms_function
def construct(self, dy, x):
return self.ssg(dy, self.shape(x), (0, 0, 0), (1, 4, 2), (1, 1, 1))

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice2():
x = Tensor(np.arange(2 * 4 * 2).reshape(2, 4, 2), mstype.float32)
dy = Tensor(np.arange(4 * 2).reshape(4, 2), mstype.float32)
ssg = StridedSliceGrad2()
output = ssg(dy, x)
expect = [[[0., 1.], [2., 3.], [4., 5.], [6., 7.]], [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]
assert (output.asnumpy() == expect).all()

if __name__ == '__main__':
test_slice()
test_slice2()

+ 1
- 1
tests/st/ops/cpu/test_stridedslice_op.py View File

@@ -34,7 +34,7 @@ class StridedSlice(nn.Cell):




@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_slice(): def test_slice():
x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32)) x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32))


Loading…
Cancel
Save