Browse Source

!27203 while inplace op for cpu backend

Merge pull request !27203 from zhuzhongrui/gmres
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
96f56b827e
4 changed files with 26 additions and 24 deletions
  1. +5
    -6
      mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc
  2. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_functor_gpu_kernel.h
  3. +13
    -10
      tests/st/ops/cpu/test_scatter_nd_update_op.py
  4. +6
    -6
      tests/st/ops/gpu/test_scatter_nd_func_op.py

+ 5
- 6
mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc View File

@@ -141,7 +141,11 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
template <typename T>
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto x = reinterpret_cast<T *>(outputs[0]->addr);
auto ret = memcpy_s(x, outputs[0]->size, inputs[0]->addr, inputs[0]->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret;
}
ComputeParams<T> params;
params.x_ = x;
params.indices_ = reinterpret_cast<int *>(inputs[1]->addr);
@@ -165,11 +169,6 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input
start += once_compute_size;
}
(void)common::ThreadPool::GetInstance().SyncRun(tasks);

auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret;
}
}
} // namespace kernel
} // namespace mindspore

+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_functor_gpu_kernel.h View File

@@ -56,12 +56,12 @@ class ScatterNdFunctorKernel : public GpuKernel {
cudaMemcpyAsync(indices_stride, &out_strides_[0], indices_len, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in ScatterNdFunctorGpuFwdKernel::Launch.");
CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices,
updates, input, reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices,
updates, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}



+ 13
- 10
tests/st/ops/cpu/test_scatter_nd_update_op.py View File

@@ -36,6 +36,7 @@ def test_op1(dtype):
Description: test cases for updating float values
Expectation: the result match scipy
"""

class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
@@ -50,10 +51,10 @@ def test_op1(dtype):
update = Tensor(np.array([1.0, 2.2], dtype=dtype))

scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data.asnumpy())
output = scatter_nd_update(indices, update)
print("x:\n", output.asnumpy())
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
assert np.allclose(output.asnumpy(),
np.array(expect, dtype=dtype))


@@ -67,6 +68,7 @@ def test_op2(dtype):
Description: test cases for updating int values
Expectation: the result match scipy
"""

class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
@@ -81,10 +83,10 @@ def test_op2(dtype):
update = Tensor(np.array([9, 10, 11, 12], dtype=dtype))

scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data.asnumpy())
output = scatter_nd_update(indices, update)
print("x:\n", output.asnumpy())
expect = [1, 11, 3, 10, 9, 6, 7, 12]
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
assert np.allclose(output.asnumpy(),
np.array(expect, dtype=dtype))


@@ -98,6 +100,7 @@ def test_op3(dtype):
Description: test cases for updating int values
Expectation: the result match scipy
"""

class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
@@ -114,13 +117,13 @@ def test_op3(dtype):
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=dtype))

scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data.asnumpy())
output = scatter_nd_update(indices, update)
print("x:\n", output.asnumpy())
expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=dtype))
assert np.allclose(output.asnumpy(), np.array(expect, dtype=dtype))


@pytest.mark.level0
@@ -133,6 +136,7 @@ def test_op4(dtype):
Description: test cases for updating single float value
Expectation: the result match scipy
"""

class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
@@ -148,6 +152,5 @@ def test_op4(dtype):
scatter_nd_update = ScatterNdUpdate()
out = scatter_nd_update(indices, update)
print("x:\n", out)
assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))

+ 6
- 6
tests/st/ops/gpu/test_scatter_nd_func_op.py View File

@@ -92,21 +92,21 @@ def test_scatter_nd_func_input_updated():

# update
net = TestScatterNdFuncNet("update", lock, inputx, indices, updates)
net()
output = net()
expected = np.array([[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]])
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

# add
net = TestScatterNdFuncNet("add", lock, inputx, indices, updates)
net()
output = net()
expected = np.array([[0.9, 0.3, 3.6], [0.4, 2.7, -3.2]])
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

# sub
net = TestScatterNdFuncNet("sub", lock, inputx, indices, updates)
net()
output = net()
expected = np.array([[-1.1, 0.3, 3.6], [0.4, -1.7, -3.2]])
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)


@pytest.mark.level0


Loading…
Cancel
Save