Merge pull request !3083 from zhaoting/master (tags/v0.6.0-beta)
@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
#include <vector>
#include <memory>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
@@ -27,40 +28,35 @@ namespace kernel {
template <typename T>
class ConcatV2GpuFwdKernel : public GpuKernel {
 public:
  ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {}
  ConcatV2GpuFwdKernel()
      : axis_(0),
        input_num_(1),
        output_size_(0),
        all_size_before_axis_(1),
        all_size_axis_(1),
        inputs_host_(nullptr),
        len_axis_(nullptr) {}
  ~ConcatV2GpuFwdKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (inputs.size() == 2) {
      T *input_0 = GetDeviceAddress<T>(inputs, 0);
      T *input_1 = GetDeviceAddress<T>(inputs, 1);
      T *output = GetDeviceAddress<T>(outputs, 0);
      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    if (inputs.size() == 3) {
      T *input_0 = GetDeviceAddress<T>(inputs, 0);
      T *input_1 = GetDeviceAddress<T>(inputs, 1);
      T *input_2 = GetDeviceAddress<T>(inputs, 2);
      T *output = GetDeviceAddress<T>(outputs, 0);
      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    if (inputs.size() == 4) {
      T *input_0 = GetDeviceAddress<T>(inputs, 0);
      T *input_1 = GetDeviceAddress<T>(inputs, 1);
      T *input_2 = GetDeviceAddress<T>(inputs, 2);
      T *input_3 = GetDeviceAddress<T>(inputs, 3);
      T *output = GetDeviceAddress<T>(outputs, 0);
      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    T *output = GetDeviceAddress<T>(outputs, 0);
    T **inputs_device = GetDeviceAddress<T *>(workspace, 0);
    int *len_axis_device = GetDeviceAddress<int>(workspace, 1);
    for (size_t i = 0; i < inputs.size(); i++) {
      inputs_host_[i] = GetDeviceAddress<T>(inputs, i);
    }
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_device, inputs_host_.get(), sizeof(T *) * input_num_,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "ConcatV2 opt cudaMemcpyAsync inputs failed");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(len_axis_device, len_axis_.get(), sizeof(int) * input_num_,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "ConcatV2 opt cudaMemcpyAsync length on axis failed");
    ConcatKernel(output_size_, input_num_, all_size_before_axis_, all_size_axis_, len_axis_device, inputs_device,
                 output, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
@@ -74,25 +70,34 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
      axis_ += SizeToInt(input_shape.size());
    }
    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    for (size_t i = 0; i < input_num; i++) {
      auto input_size = sizeof(T);
    input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node));
    inputs_host_ = std::make_unique<T *[]>(input_num_);
    len_axis_ = std::make_unique<int[]>(input_num_);
    for (int i = 0; i < input_num_; i++) {
      int input_size = 1;
      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
      for (size_t j = 0; j < input_shape.size(); j++) {
        input_size *= SizeToInt(input_shape[j]);
        if (j >= IntToSize(axis_)) {
          w_[i] *= SizeToInt(input_shape[j]);
        }
      input_size_list_.push_back(input_size);
      }
      input_size_list_.push_back(IntToSize(input_size * sizeof(T)));
      len_axis_[i] = SizeToInt(input_shape[axis_]);
    }
    workspace_size_list_.push_back(sizeof(T *) * input_num_);
    workspace_size_list_.push_back(sizeof(int) * input_num_);
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    output_size_ = sizeof(T);
    for (size_t i = 0; i < output_shape.size(); i++) {
    output_size_ = 1;
    for (int i = 0; i < SizeToInt(output_shape.size()); i++) {
      output_size_ *= output_shape[i];
      if (i > axis_) {
        all_size_before_axis_ *= output_shape[i];
        all_size_axis_ *= output_shape[i];
      }
      if (i == axis_) {
        all_size_before_axis_ *= output_shape[i];
      }
    }
    output_size_list_.push_back(output_size_);
    output_size_list_.push_back(IntToSize(output_size_ * sizeof(T)));
    InitSizeLists();
    return true;
@@ -103,11 +108,6 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
 private:
  bool CheckParam(const CNodePtr &kernel_node) {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num < 2 || input_num > 4) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output.";
@@ -115,9 +115,13 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
    }
    return true;
  }
  int w_[4] = {1, 1, 1, 1};
  int axis_;
  size_t output_size_;
  int input_num_;
  int output_size_;
  int all_size_before_axis_;
  int all_size_axis_;
  std::unique_ptr<T *[]> inputs_host_;
  std::unique_ptr<int[]> len_axis_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
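The rewritten Init() above replaces the fixed w_[4] widths with per-input bookkeeping: input_size_list_, len_axis_, output_size_, all_size_axis_ and all_size_before_axis_. A minimal NumPy sketch of what those members evaluate to for a hypothetical three-input concat on axis 1 (shapes and dtype are illustrative only, not taken from the patch):

import numpy as np

# Hypothetical example: three float32 inputs concatenated on axis 1.
input_shapes = [(2, 1, 4), (2, 3, 4), (2, 2, 4)]
axis = 1
elem = np.dtype(np.float32).itemsize

input_size_list = [int(np.prod(s)) * elem for s in input_shapes]  # [32, 96, 64] bytes
len_axis = [s[axis] for s in input_shapes]                        # [1, 3, 2]

output_shape = (2, sum(len_axis), 4)                              # (2, 6, 4)
output_size = int(np.prod(output_shape))                          # 48 elements
all_size_axis = int(np.prod(output_shape[axis + 1:]))             # 4, product of dims after the axis
all_size_before_axis = output_shape[axis] * all_size_axis         # 24, product of dims from the axis onward

The two workspace entries then hold input_num_ device pointers and the len_axis table, which Launch() copies to the device with cudaMemcpyAsync before invoking the kernel.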
@@ -0,0 +1,31 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  SplitGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Split,
                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      SplitGpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(
  Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  SplitGpuFwdKernel, half)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,153 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
#include <vector>
#include <memory>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class SplitGpuFwdKernel : public GpuKernel {
 public:
  SplitGpuFwdKernel()
      : axis_(0),
        output_num_(1),
        input_size_(1),
        axis_step_(1),
        all_size_before_axis_(1),
        all_size_axis_(1),
        outputs_host_(nullptr) {}
  ~SplitGpuFwdKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input = GetDeviceAddress<T>(inputs, 0);
    T **outputs_device = GetDeviceAddress<T *>(workspace, 0);
    for (size_t i = 0; i < outputs.size(); i++) {
      outputs_host_[i] = GetDeviceAddress<T>(outputs, i);
    }
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_device, outputs_host_.get(), sizeof(T *) * output_num_,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Split opt cudaMemcpyAsync outputs failed");
    SplitKernel(input_size_, axis_step_, all_size_before_axis_, all_size_axis_, input, outputs_device,
                reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    axis_ = GetAttr<int>(kernel_node, "axis");
    if (axis_ < 0) {
      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
      axis_ += SizeToInt(input_shape.size());
    }
    output_num_ = GetAttr<int>(kernel_node, "output_num");
    if (!CheckParam(kernel_node)) {
      return false;
    }
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    input_size_ = 1;
    all_size_before_axis_ = 1;
    all_size_axis_ = 1;
    for (int i = 0; i < SizeToInt(input_shape.size()); i++) {
      input_size_ *= input_shape[i];
      if (i > axis_) {
        all_size_before_axis_ *= input_shape[i];
        all_size_axis_ *= input_shape[i];
      }
      if (i == axis_) {
        all_size_before_axis_ *= input_shape[i];
      }
    }
    input_size_list_.push_back(IntToSize(input_size_ * sizeof(T)));
    axis_step_ = input_shape[axis_] / output_num_;
    for (int i = 0; i < output_num_; i++) {
      size_t output_size = 1;
      auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
      for (size_t j = 0; j < output_shape.size(); j++) {
        output_size *= output_shape[j];
      }
      output_size_list_.push_back(output_size * sizeof(T));
    }
    workspace_size_list_.push_back(sizeof(T *) * output_num_);
    InitSizeLists();
    outputs_host_ = std::make_unique<T *[]>(output_num_);
    return true;
  }
 protected:
  void InitSizeLists() override {}
 private:
  bool CheckParam(const CNodePtr &kernel_node) {
    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    int dims = SizeToInt(input_shape.size());
    int output_num = SizeToInt(AnfAlgo::GetOutputTensorNum(kernel_node));
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but Split needs 1 input.";
      return false;
    }
    if (dims == 0) {
      MS_LOG(ERROR) << "Input dims is " << dims << ", scalar is not supported.";
      return false;
    }
    if (axis_ < -dims || axis_ >= dims) {
      MS_LOG(ERROR) << "Attr axis " << axis_ << " must be in range [" << -dims << ", " << dims << ").";
      return false;
    }
    if (output_num_ > SizeToInt(input_shape[axis_])) {
      MS_LOG(ERROR) << "Attr output_num " << output_num_ << " must not be greater than " << input_shape[axis_] << ".";
      return false;
    }
    if (input_shape[axis_] % output_num_ != 0) {
      MS_LOG(ERROR) << "Attr output_num " << output_num_ << " must evenly divide the axis dimension "
                    << input_shape[axis_] << ".";
      return false;
    }
    if (output_num_ != output_num) {
      MS_LOG(ERROR) << "Output num is " << output_num << ", but Split needs " << output_num_ << " outputs.";
      return false;
    }
    return true;
  }
  int axis_;
  int output_num_;
  int input_size_;
  int axis_step_;
  int all_size_before_axis_;
  int all_size_axis_;
  std::unique_ptr<T *[]> outputs_host_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
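The Split bookkeeping mirrors the concat case. A small sketch, assuming a hypothetical (2, 6, 4) float32 input split into 3 equal pieces on axis 1 (again illustrative values, not taken from the patch):

import numpy as np

input_shape = (2, 6, 4)
axis, output_num = 1, 3
assert input_shape[axis] % output_num == 0                 # mirrors the CheckParam divisibility check

input_size = int(np.prod(input_shape))                     # 48 elements
axis_step = input_shape[axis] // output_num                # 2, axis length of each output piece
all_size_axis = int(np.prod(input_shape[axis + 1:]))       # 4, product of dims after the axis
all_size_before_axis = input_shape[axis] * all_size_axis   # 24, product of dims from the axis onward

elem = np.dtype(np.float32).itemsize
output_size_list = [input_size // output_num * elem] * output_num  # [64, 64, 64] bytes
# The single workspace entry holds output_num device pointers, filled in Launch().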
@@ -19,90 +19,51 @@
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
template <typename T>
__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int n = pos / (w1 + w2);
    int m = pos % (w1 + w2);
    output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m];
__global__ void Concat(const int size, const int input_num,
                       const int all_size_before_axis, const int all_size_axis,
                       int* len_axis, T** inputs, T* output) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int num = pos % all_size_before_axis / all_size_axis;
    int block = -1;
    int axis_inc = 0;
    int block_len = 0;
    for (int i = 0; i < input_num; i++) {
      if (axis_inc <= num) {
        block++;
        axis_inc += len_axis[i];
      } else {
        break;
      }
    }
    block_len = len_axis[block];
    axis_inc -= len_axis[block];
    int block_pos = pos / all_size_before_axis * block_len * all_size_axis +
                    (num - axis_inc) * all_size_axis + pos % all_size_axis;
    output[pos] = inputs[block][block_pos];
  }
  return;
}
template <typename T>
__global__ void Concat(const size_t size, const int w1, const int w2, const int w3,
                       const T* input_1, const T* input_2, const T* input_3, T* output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int n = pos / (w1 + w2 + w3);
    int m = pos % (w1 + w2 + w3);
    output[pos] = m < w1 ? input_1[n * w1 + m] :
                  m < w1 + w2 ? input_2[n * w2 + m - w1] :
                  input_3[n * w3 + m - w1 - w2];
  }
  return;
}
template <typename T>
__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4,
                       const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int n = pos / (w1 + w2 + w3 + w4);
    int m = pos % (w1 + w2 + w3 + w4);
    output[pos] = m < w1 ? input_1[n * w1 + m] :
                  m < w1 + w2 ? input_2[n * w2 + m - w1]:
                  m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]:
                  input_4[n * w4 + m - w1 - w2 - w3];
  }
  return;
}
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
                  cudaStream_t cuda_stream) {
  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
  return;
}
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
                  const T* input_1, const T* input_2, const T* input_3, T* output,
void ConcatKernel(const int size, const int input_num,
                  const int all_size_before_axis, const int all_size_axis,
                  int* len_axis, T** inputs, T* output,
                  cudaStream_t cuda_stream) {
  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output);
  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num,
                                                            all_size_before_axis, all_size_axis,
                                                            len_axis, inputs, output);
  return;
}
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
                  const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
                  cudaStream_t cuda_stream) {
  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1,
                                                            input_2, input_3, input_4, output);
  return;
}
template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2,
                           float* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2,
                           int* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
                           half* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
                           const float* input_1, const float* input_2, const float* input_3,
                           float* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
                           const int* input_1, const int* input_2, const int* input_3,
                           int* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
                           const half* input_1, const half* input_2, const half* input_3,
                           half* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
                           const float* input_1, const float* input_2, const float* input_3, const float* input_4,
                           float* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
                           const int* input_1, const int* input_2, const int* input_3, const int* input_4,
                           int* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
                           const half* input_1, const half* input_2, const half* input_3, const half* input_4,
                           half* output, cudaStream_t cuda_stream);
template void ConcatKernel(const int size, const int input_num,
                           const int all_size_before_axis, const int all_size_axis,
                           int* len_axis, float** inputs, float* output,
                           cudaStream_t cuda_stream);
template void ConcatKernel(const int size, const int input_num,
                           const int all_size_before_axis, const int all_size_axis,
                           int* len_axis, int** inputs, int* output,
                           cudaStream_t cuda_stream);
template void ConcatKernel(const int size, const int input_num,
                           const int all_size_before_axis, const int all_size_axis,
                           int* len_axis, half** inputs, half* output,
                           cudaStream_t cuda_stream);
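The generalized Concat kernel above replaces the hard-coded 2-, 3- and 4-input variants with one index mapping driven by len_axis. A host-side NumPy re-implementation of that mapping, cross-checked against np.concatenate, may help when reviewing the index arithmetic (the function below is a reviewer's sketch, not part of the patch):

import numpy as np

def concat_reference(inputs, axis):
    # Mirror of the Concat CUDA kernel's flat-index mapping, evaluated on the host.
    out_shape = list(inputs[0].shape)
    out_shape[axis] = sum(x.shape[axis] for x in inputs)
    all_size_axis = int(np.prod(out_shape[axis + 1:]))       # product of dims after the axis
    all_size_before_axis = out_shape[axis] * all_size_axis   # product of dims from the axis onward
    len_axis = [x.shape[axis] for x in inputs]
    flat_inputs = [x.reshape(-1) for x in inputs]
    out = np.empty(int(np.prod(out_shape)), dtype=inputs[0].dtype)
    for pos in range(out.size):
        num = pos % all_size_before_axis // all_size_axis    # coordinate along the concat axis
        block, axis_inc = -1, 0
        for length in len_axis:                              # find the source tensor, as in the kernel loop
            if axis_inc <= num:
                block += 1
                axis_inc += length
            else:
                break
        block_len = len_axis[block]
        axis_inc -= block_len                                # start of this block along the axis
        block_pos = (pos // all_size_before_axis * block_len * all_size_axis +
                     (num - axis_inc) * all_size_axis + pos % all_size_axis)
        out[pos] = flat_inputs[block][block_pos]
    return out.reshape(out_shape)

xs = [np.random.randn(2, k, 4).astype(np.float32) for k in (1, 3, 2)]
assert np.array_equal(concat_reference(xs, axis=1), np.concatenate(xs, axis=1))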
@@ -19,13 +19,8 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
                  cudaStream_t cuda_stream);
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
                  const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream);
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
                  const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
void ConcatKernel(const int size, const int input_num,
                  const int all_size_before_axis, const int all_size_axis,
                  int* len_axis, T** inputs, T* output,
                  cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
@@ -0,0 +1,50 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <stdio.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
template <typename T>
__global__ void Split(const int size, const int axis_step, const int all_size_before_axis,
                      const int all_size_axis, const T* input, T** outputs) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    int num = pos % all_size_before_axis / all_size_axis;
    int block = num / axis_step;
    int block_pos = pos / all_size_before_axis * axis_step * all_size_axis +
                    num % axis_step * all_size_axis + pos % all_size_axis;
    outputs[block][block_pos] = input[pos];
  }
  return;
}
template <typename T>
void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
                 const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream) {
  Split<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, axis_step, all_size_before_axis,
                                                           all_size_axis, input, outputs);
  return;
}
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
                          const int all_size_axis, const float* input, float** outputs,
                          cudaStream_t cuda_stream);
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
                          const int all_size_axis, const int* input, int** outputs,
                          cudaStream_t cuda_stream);
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
                          const int all_size_axis, const half* input, half** outputs,
                          cudaStream_t cuda_stream);
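Similarly for Split, a NumPy sketch of the kernel's flat-index mapping, checked against np.split (reviewer's sketch only, not part of the patch):

import numpy as np

def split_reference(x, axis, output_num):
    # Mirror of the Split CUDA kernel's flat-index mapping, evaluated on the host.
    axis_step = x.shape[axis] // output_num
    all_size_axis = int(np.prod(x.shape[axis + 1:]))        # product of dims after the axis
    all_size_before_axis = x.shape[axis] * all_size_axis    # product of dims from the axis onward
    out_shape = list(x.shape)
    out_shape[axis] = axis_step
    flat = x.reshape(-1)
    outs = [np.empty(int(np.prod(out_shape)), dtype=x.dtype) for _ in range(output_num)]
    for pos in range(flat.size):
        num = pos % all_size_before_axis // all_size_axis   # coordinate along the split axis
        block = num // axis_step                            # which output receives this element
        block_pos = (pos // all_size_before_axis * axis_step * all_size_axis +
                     num % axis_step * all_size_axis + pos % all_size_axis)
        outs[block][block_pos] = flat[pos]
    return [o.reshape(out_shape) for o in outs]

x = np.random.randn(2, 6, 4).astype(np.float32)
for got, want in zip(split_reference(x, axis=1, output_num=3), np.split(x, 3, axis=1)):
    assert np.array_equal(got, want)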
@@ -0,0 +1,24 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
                 const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self, axis=0, out_nums=1):
        super(Net, self).__init__()
        self.split = P.Split(axis, out_nums)

    def construct(self, x):
        return self.split(x)


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split():
    x = np.array([[[1, -1, 1], [2, -2, 2]],
                  [[3, -3, 3], [4, -4, 4]],
                  [[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
    split_op = Net(0, 3)
    outputs = split_op(Tensor(x))
    for i, out in enumerate(outputs):
        assert (out.asnumpy() == x[i]).all()


def test_split_4d():
    x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
    y = np.split(x_np, 3, axis=1)
    split_op = Net(1, 3)
    outputs = split_op(Tensor(x_np))
    for i, out in enumerate(outputs):
        assert (out.asnumpy() == y[i]).all()
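The ConcatV2 rewrite also drops the old two-to-four input limit, so a concat with more than four inputs should now be exercisable from Python as well. A hedged sketch of such a test, assuming P.Concat takes the axis in its constructor and a tuple of tensors in construct, in the same style as the Split tests above (ConcatNet and test_concat_five_inputs are hypothetical names, not part of the patch):

class ConcatNet(nn.Cell):
    def __init__(self, axis=0):
        super(ConcatNet, self).__init__()
        self.concat = P.Concat(axis)

    def construct(self, *inputs):
        return self.concat(inputs)


def test_concat_five_inputs():
    # Five inputs would have been rejected by the old 2..4 input check in ConcatV2GpuFwdKernel.
    xs = [np.random.randn(2, 3, 4).astype(np.float32) for _ in range(5)]
    expect = np.concatenate(xs, axis=1)
    concat_op = ConcatNet(axis=1)
    output = concat_op(*[Tensor(x) for x in xs])
    assert (output.asnumpy() == expect).all()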