| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(ScatterUpdate, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| ScatterUpdateKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(ScatterUpdate, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| ScatterUpdateKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,95 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_UPDATE_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_UPDATE_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class ScatterUpdateKernel : public GpuKernel { | |||
| public: | |||
| ScatterUpdateKernel() : input_size_(0), inner_size_(0), indices_size_(0), updates_size_(0) {} | |||
| ~ScatterUpdateKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| int *indices = GetDeviceAddress<int>(inputs, 1); | |||
| T *updates = GetDeviceAddress<T>(inputs, 2); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| CalScatterUpdate(input_size_, inner_size_, indices_size_, input, indices, updates, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 3) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ScatterUpdate needs 3 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but ScatterUpdate has 1 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| input_size_ = 1; | |||
| inner_size_ = 1; | |||
| for (size_t i = 1; i < input_shape.size(); i++) { | |||
| inner_size_ *= input_shape[i]; | |||
| } | |||
| input_size_ = input_shape[0] * inner_size_; | |||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| indices_size_ = 1; | |||
| for (size_t i = 0; i < indices_shape.size(); i++) { | |||
| indices_size_ *= indices_shape[i]; | |||
| } | |||
| updates_size_ = indices_size_ * inner_size_; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| input_size_list_.push_back(indices_size_ * sizeof(int)); | |||
| input_size_list_.push_back(updates_size_ * sizeof(T)); | |||
| output_size_list_.push_back(input_size_ * sizeof(T)); | |||
| } | |||
| private: | |||
| int input_size_; | |||
| int inner_size_; | |||
| int indices_size_; | |||
| int updates_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_UPDATE_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh" | |||
| template <typename T> | |||
| __global__ void ScatterUpdate(const int input_size, const int inner_size, const int indices_size, const T *input, | |||
| const int *indices, const T *updates, T *output) { | |||
| for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) { | |||
| output[pos] = input[pos]; | |||
| const int index = pos / inner_size; | |||
| const int offset = pos % inner_size; | |||
| for (int i = 0; i < indices_size; i++) { | |||
| const int update_pos = i * inner_size + offset; | |||
| output[pos] = (indices[i] == index ? updates[update_pos] : output[pos]); | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalScatterUpdate(const int &input_size, const int &inner_size, const int &indices_size, const T *input, | |||
| const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) { | |||
| ScatterUpdate<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, inner_size, indices_size, input, | |||
| indices, updates, output); | |||
| } | |||
| template void CalScatterUpdate<float>(const int &input_size, const int &inner_size, const int &indices_size, | |||
| const float *input, const int *indices, const float *updates, float *output, | |||
| cudaStream_t cuda_stream); | |||
| template void CalScatterUpdate<half>(const int &input_size, const int &inner_size, const int &indices_size, | |||
| const half *input, const int *indices, const half *updates, half *output, | |||
| cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_UPDATE_IMPL_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_UPDATE_IMPL_CUH_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalScatterUpdate(const int &input_size, const int &inner_size, const int &indices_size, const T *input, | |||
| const int *indices, const T *updates, T *output, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_UPDATE_IMPL_CUH_ | |||
| @@ -0,0 +1,106 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| # all cases tested against dchip | |||
| class TestScatterUpdateNet(nn.Cell): | |||
| def __init__(self, inputx, indices, updates): | |||
| super(TestScatterUpdateNet, self).__init__() | |||
| self.scatter_update = P.ScatterUpdate() | |||
| self.inputx = Parameter(inputx, name="inputx") | |||
| self.indices = Parameter(indices, name="indices") | |||
| self.updates = Parameter(updates, name="updates") | |||
| def construct(self): | |||
| out = self.scatter_update(self.inputx, self.indices, self.updates) | |||
| return out | |||
| def scatter_update_net(inputx, indices, updates): | |||
| net = TestScatterUpdateNet(inputx, indices, updates) | |||
| return net() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_update_small_float32(): | |||
| inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) | |||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32)) | |||
| output = scatter_update_net(inputx, indices, updates) | |||
| expected = np.array([[0., 1., 2.], | |||
| [3., 4., 5.]]) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_update_input_less_than_1_float32(): | |||
| inputx = Tensor(np.array([[0.214141, 0.415151, 0.51516], | |||
| [0.876542, 0.451611, 0.55112], | |||
| [0.111244, 0.633333, 0.34444]]).astype(np.float32)) | |||
| indices = Tensor(np.array([1, 0, 2]).astype(np.int32)) | |||
| updates = Tensor(np.arange(34, 43).reshape((3, 3)).astype(np.float32)) | |||
| output = scatter_update_net(inputx, indices, updates) | |||
| expected = np.array([[37., 38., 39.], | |||
| [34., 35., 36.], | |||
| [40., 41., 42.]], dtype=np.float32) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_update_float16(): | |||
| inputx = Tensor(np.zeros((2, 3)).astype(np.float16)) | |||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float16)) | |||
| output = scatter_update_net(inputx, indices, updates) | |||
| expected = np.array([[0., 1., 2.], | |||
| [3., 4., 5.]]) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_update_large_float16(): | |||
| inputx = Tensor(np.zeros((4, 3)).astype(np.float16)) | |||
| indices = Tensor(np.array([[2, 1], [0, 3]]).astype(np.int32)) | |||
| updates = Tensor(np.arange(63, 75).reshape((2, 2, 3)).astype(np.float16)) | |||
| output = scatter_update_net(inputx, indices, updates) | |||
| expected = np.array([[69., 70., 71.], | |||
| [66., 67., 68.], | |||
| [63., 64., 65.], | |||
| [72., 73., 74.]]) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_update_disordered_float16(): | |||
| inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.float16))) | |||
| indices = Tensor(np.array([1, 2]).astype(np.int32)) | |||
| updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.float16)) | |||
| output = scatter_update_net(inputx, indices, updates) | |||
| expected = np.array([[45., 44., 43., 42.], | |||
| [63., 64., 65., 66.], | |||
| [67., 68., 69., 70.]]) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||