From: @yuan_shen_zhou Reviewed-by: @liangchenghui Signed-off-by: @liangchenghuitags/v1.1.0
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/arrays/unpack_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnpackGpuFwdKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnpackGpuFwdKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Unpack, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| UnpackGpuFwdKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE(Unpack, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), | |||
| UnpackGpuFwdKernel, int16_t) | |||
| MS_REG_GPU_KERNEL_ONE(Unpack, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||
| UnpackGpuFwdKernel, uchar) | |||
| MS_REG_GPU_KERNEL_ONE(Unpack, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||
| UnpackGpuFwdKernel, bool) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,112 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNPACK_GPU_KERNEL_H | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNPACK_GPU_KERNEL_H | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/unpack.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class UnpackGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| UnpackGpuFwdKernel() : axis_(0), output_num_(0), input_size_(1), dims_after_axis_(1), outputs_host_(nullptr) {} | |||
| ~UnpackGpuFwdKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T **outputs_array = GetDeviceAddress<T *>(workspace, 0); | |||
| for (size_t i = 0; i < outputs.size(); i++) { | |||
| outputs_host_[i] = GetDeviceAddress<T>(outputs, i); | |||
| } | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_array, outputs_host_.get(), sizeof(T *) * output_num_, | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "Unpack opt cudaMemcpyAsync outputs failed"); | |||
| UnpackKernel(SizeToInt(input_size_), output_num_, dims_after_axis_, outputs_array, input, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| if (!CheckParam(kernel_node)) { | |||
| return false; | |||
| } | |||
| axis_ = static_cast<int32_t>(GetAttr<int64_t>(kernel_node, "axis")); | |||
| if (axis_ < 0) { | |||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| axis_ += SizeToInt(input_shape.size()); | |||
| } | |||
| auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node); | |||
| auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| axis_ = AxisTransform(origin_data_format, input_format, axis_); | |||
| output_num_ = static_cast<int32_t>(GetAttr<int64_t>(kernel_node, "num")); | |||
| outputs_host_ = std::make_unique<T *[]>(output_num_); | |||
| for (int i = 0; i < output_num_; i++) { | |||
| size_t _size = 1; | |||
| auto _shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i); | |||
| for (size_t j = 0; j < _shape.size(); j++) { | |||
| _size *= _shape[j]; | |||
| } | |||
| output_size_list_.push_back(_size * sizeof(T)); | |||
| } | |||
| workspace_size_list_.push_back(sizeof(T *) * output_num_); | |||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| for (int i = 0; i < SizeToInt(input_shape.size()); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| if (i > axis_) { | |||
| dims_after_axis_ *= input_shape[i]; | |||
| } | |||
| } | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override {} | |||
| private: | |||
| bool CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "input number is " << input_num << ", but UnpackGpuFwdKernel needs 1 input."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| int axis_; | |||
| int output_num_; | |||
| size_t input_size_; | |||
| int dims_after_axis_; | |||
| std::unique_ptr<T *[]> outputs_host_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNPACK_GPU_KERNEL_H | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <stdio.h> | |||
| #include <stdint.h> | |||
| #include <cuda_runtime.h> | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/unpack.cuh" | |||
| template <typename T> | |||
| __global__ void Unpack(const int size, const int output_num, | |||
| const int dims_after_axis, T** outputs, const T* input) { | |||
| for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| int cycle = pos / (output_num * dims_after_axis); | |||
| int cur_output_index = pos % (output_num * dims_after_axis) / dims_after_axis; | |||
| int local_index = pos % (output_num * dims_after_axis) % dims_after_axis; | |||
| outputs[cur_output_index][cycle * dims_after_axis + local_index] = input[pos]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void UnpackKernel(const int size, const int output_num, | |||
| const int dims_after_axis, T** outputs, const T* input, | |||
| cudaStream_t cuda_stream) { | |||
| Unpack<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, output_num, | |||
| dims_after_axis, outputs, input); | |||
| return; | |||
| } | |||
| template void UnpackKernel(const int size, const int output_num, | |||
| const int dims_after_axis, float** outputs, const float* input, | |||
| cudaStream_t cuda_stream); | |||
| template void UnpackKernel(const int size, const int output_num, | |||
| const int dims_after_axis, half** outputs, const half* input, | |||
| cudaStream_t cuda_stream); | |||
| template void UnpackKernel(const int size, const int output_num, | |||
| const int dims_after_axis, int** outputs, const int* input, | |||
| cudaStream_t cuda_stream); | |||
| template void UnpackKernel(const int size, const int output_num, | |||
| const int dims_after_axis, int16_t** outputs, const int16_t* input, | |||
| cudaStream_t cuda_stream); | |||
| template void UnpackKernel(const int size, const int output_num, | |||
| const int dims_after_axis, unsigned char** outputs, const unsigned char* input, | |||
| cudaStream_t cuda_stream); | |||
| template void UnpackKernel(const int size, const int output_num, | |||
| const int dims_after_axis, bool** outputs, const bool* input, | |||
| cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNPACKIMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNPACKIMPL_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void UnpackKernel(const int size, const int output_num, | |||
| const int dims_after_axis, T** outputs, const T* input, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNPACKIMPL_H_ | |||
| @@ -0,0 +1,162 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations.array_ops as P | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| class UnpackNet(nn.Cell): | |||
| def __init__(self, nptype): | |||
| super(UnpackNet, self).__init__() | |||
| self.unpack = P.Unpack(axis=3) | |||
| self.data_np = np.array([[[[[0, 0], | |||
| [0, 1]], | |||
| [[0, 0], | |||
| [2, 3]]], | |||
| [[[0, 0], | |||
| [4, 5]], | |||
| [[0, 0], | |||
| [6, 7]]]], | |||
| [[[[0, 0], | |||
| [8, 9]], | |||
| [[0, 0], | |||
| [10, 11]]], | |||
| [[[0, 0], | |||
| [12, 13]], | |||
| [[0, 0], | |||
| [14, 15]]]]]).astype(nptype) | |||
| self.x1 = Parameter(initializer(Tensor(self.data_np), [2, 2, 2, 2, 2]), name='x1') | |||
| @ms_function | |||
| def construct(self): | |||
| return self.unpack(self.x1) | |||
| def unpack(nptype): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| unpack_ = UnpackNet(nptype) | |||
| output = unpack_() | |||
| expect = (np.reshape(np.array([0] * 16).astype(nptype), (2, 2, 2, 2)), | |||
| np.arange(2 * 2 * 2 * 2).reshape(2, 2, 2, 2).astype(nptype)) | |||
| for i, exp in enumerate(expect): | |||
| assert (output[i].asnumpy() == exp).all() | |||
| def unpack_pynative(nptype): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| x1 = np.array([[[[[0, 0], | |||
| [0, 1]], | |||
| [[0, 0], | |||
| [2, 3]]], | |||
| [[[0, 0], | |||
| [4, 5]], | |||
| [[0, 0], | |||
| [6, 7]]]], | |||
| [[[[0, 0], | |||
| [8, 9]], | |||
| [[0, 0], | |||
| [10, 11]]], | |||
| [[[0, 0], | |||
| [12, 13]], | |||
| [[0, 0], | |||
| [14, 15]]]]]).astype(nptype) | |||
| x1 = Tensor(x1) | |||
| expect = (np.reshape(np.array([0] * 16).astype(nptype), (2, 2, 2, 2)), | |||
| np.arange(2 * 2 * 2 * 2).reshape(2, 2, 2, 2).astype(nptype)) | |||
| output = P.Unpack(axis=3)(x1) | |||
| for i, exp in enumerate(expect): | |||
| assert (output[i].asnumpy() == exp).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_graph_float32(): | |||
| unpack(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_graph_float16(): | |||
| unpack(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_graph_int32(): | |||
| unpack(np.int32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_graph_int16(): | |||
| unpack(np.int16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_graph_uint8(): | |||
| unpack(np.uint8) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_graph_bool(): | |||
| unpack(np.bool) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_pynative_float32(): | |||
| unpack_pynative(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_pynative_float16(): | |||
| unpack_pynative(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_pynative_int32(): | |||
| unpack_pynative(np.int32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_pynative_int16(): | |||
| unpack_pynative(np.int16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_pynative_uint8(): | |||
| unpack_pynative(np.uint8) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unpack_pynative_bool(): | |||
| unpack_pynative(np.bool) | |||