| @@ -19,15 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| Concat, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ConcatV2GpuFwdKernel, float) | ConcatV2GpuFwdKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(Concat, | |||||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ConcatV2GpuFwdKernel, int) | |||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| Concat, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ConcatV2GpuFwdKernel, int) | |||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| Concat, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| ConcatV2GpuFwdKernel, half) | ConcatV2GpuFwdKernel, half) | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,7 +27,7 @@ namespace kernel { | |||||
| template <typename T> | template <typename T> | ||||
| class ConcatV2GpuFwdKernel : public GpuKernel { | class ConcatV2GpuFwdKernel : public GpuKernel { | ||||
| public: | public: | ||||
| ConcatV2GpuFwdKernel() : axis_(0), input0_size_(0), input1_size_(0), output_size_(0), workspace_size_(0) {} | |||||
| ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {} | |||||
| ~ConcatV2GpuFwdKernel() override = default; | ~ConcatV2GpuFwdKernel() override = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | ||||
| @@ -35,12 +35,32 @@ class ConcatV2GpuFwdKernel : public GpuKernel { | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | ||||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | ||||
| T *input_0 = GetDeviceAddress<T>(inputs, 0); | |||||
| T *input_1 = GetDeviceAddress<T>(inputs, 1); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
| if (inputs.size() == 2) { | |||||
| T *input_0 = GetDeviceAddress<T>(inputs, 0); | |||||
| T *input_1 = GetDeviceAddress<T>(inputs, 1); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
| ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| } | |||||
| if (inputs.size() == 3) { | |||||
| T *input_0 = GetDeviceAddress<T>(inputs, 0); | |||||
| T *input_1 = GetDeviceAddress<T>(inputs, 1); | |||||
| T *input_2 = GetDeviceAddress<T>(inputs, 2); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
| ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| } | |||||
| CalConcatV2(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| if (inputs.size() == 4) { | |||||
| T *input_0 = GetDeviceAddress<T>(inputs, 0); | |||||
| T *input_1 = GetDeviceAddress<T>(inputs, 1); | |||||
| T *input_2 = GetDeviceAddress<T>(inputs, 2); | |||||
| T *input_3 = GetDeviceAddress<T>(inputs, 3); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
| ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| } | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| @@ -48,44 +68,44 @@ class ConcatV2GpuFwdKernel : public GpuKernel { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| input0_size_ = sizeof(T); | |||||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||||
| input0_size_ *= input_shape[i]; | |||||
| } | |||||
| auto input_shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| input1_size_ = sizeof(T); | |||||
| for (size_t i = 0; i < input_shape1.size(); i++) { | |||||
| input1_size_ *= input_shape1[i]; | |||||
| } | |||||
| output_size_ = input0_size_ + input1_size_; | |||||
| axis_ = GetAttr<int>(kernel_node, "axis"); | axis_ = GetAttr<int>(kernel_node, "axis"); | ||||
| if (axis_ < 0) { | if (axis_ < 0) { | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| axis_ += SizeToInt(input_shape.size()); | axis_ += SizeToInt(input_shape.size()); | ||||
| } | } | ||||
| w_[0] = 1; | |||||
| w_[1] = 1; | |||||
| for (size_t i = IntToSize(axis_); i < input_shape.size(); i++) { | |||||
| w_[0] *= SizeToInt(input_shape[i]); | |||||
| w_[1] *= SizeToInt(input_shape1[i]); | |||||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| for (size_t i = 0; i < input_num; i++) { | |||||
| auto input_size = sizeof(T); | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); | |||||
| for (size_t j = 0; j < input_shape.size(); j++) { | |||||
| input_size *= SizeToInt(input_shape[j]); | |||||
| if (j >= IntToSize(axis_)) { | |||||
| w_[i] *= SizeToInt(input_shape[j]); | |||||
| } | |||||
| input_size_list_.push_back(input_size); | |||||
| } | |||||
| } | } | ||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| output_size_ = sizeof(T); | |||||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||||
| output_size_ *= output_shape[i]; | |||||
| } | |||||
| output_size_list_.push_back(output_size_); | |||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| } | } | ||||
| protected: | protected: | ||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(input0_size_); | |||||
| input_size_list_.push_back(input1_size_); | |||||
| output_size_list_.push_back(output_size_); | |||||
| } | |||||
| void InitSizeLists() override {} | |||||
| private: | private: | ||||
| bool CheckParam(const CNodePtr &kernel_node) { | bool CheckParam(const CNodePtr &kernel_node) { | ||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 2) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs 2 inputs."; | |||||
| if (input_num < 2 || input_num > 4) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | ||||
| @@ -95,16 +115,12 @@ class ConcatV2GpuFwdKernel : public GpuKernel { | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| int w_[2] = {1}; | |||||
| int w_[4] = {1, 1, 1, 1}; | |||||
| int axis_; | int axis_; | ||||
| size_t output_size_; | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| std::vector<size_t> workspace_size_list_; | std::vector<size_t> workspace_size_list_; | ||||
| size_t input0_size_; | |||||
| size_t input1_size_; | |||||
| size_t output_size_; | |||||
| size_t workspace_size_; | |||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <cuda_runtime.h> | #include <cuda_runtime.h> | ||||
| #include "kernel/gpu/cuda_impl/concatv2_impl.cuh" | #include "kernel/gpu/cuda_impl/concatv2_impl.cuh" | ||||
| template <typename T> | template <typename T> | ||||
| __global__ void ConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { | |||||
| __global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | ||||
| int n = pos / (w1 + w2); | int n = pos / (w1 + w2); | ||||
| int m = pos % (w1 + w2); | int m = pos % (w1 + w2); | ||||
| @@ -29,16 +29,80 @@ __global__ void ConcatV2(const size_t size, const int w1, const int w2, const T* | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, | |||||
| __global__ void Concat(const size_t size, const int w1, const int w2, const int w3, | |||||
| const T* input_1, const T* input_2, const T* input_3, T* output) { | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||||
| int n = pos / (w1 + w2 + w3); | |||||
| int m = pos % (w1 + w2 + w3); | |||||
| output[pos] = m < w1 ? input_1[n * w1 + m] : | |||||
| m < w1 + w2 ? input_2[n * w2 + m - w1] : | |||||
| input_3[n * w3 + m - w1 - w2]; | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||||
| const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) { | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||||
| int n = pos / (w1 + w2 + w3 + w4); | |||||
| int m = pos % (w1 + w2 + w3 + w4); | |||||
| output[pos] = m < w1 ? input_1[n * w1 + m] : | |||||
| m < w1 + w2 ? input_2[n * w2 + m - w1]: | |||||
| m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]: | |||||
| input_4[n * w4 + m - w1 - w2 - w3]; | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, | |||||
| cudaStream_t cuda_stream) { | cudaStream_t cuda_stream) { | ||||
| ConcatV2<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output); | |||||
| Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output); | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||||
| const T* input_1, const T* input_2, const T* input_3, T* output, | |||||
| cudaStream_t cuda_stream) { | |||||
| Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output); | |||||
| return; | return; | ||||
| } | } | ||||
| template void CalConcatV2(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, | |||||
| float* output, cudaStream_t cuda_stream); | |||||
| template void CalConcatV2(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, | |||||
| int* output, cudaStream_t cuda_stream); | |||||
| template void CalConcatV2(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, | |||||
| half* output, cudaStream_t cuda_stream); | |||||
| template <typename T> | |||||
| void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||||
| const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, | |||||
| cudaStream_t cuda_stream) { | |||||
| Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1, | |||||
| input_2, input_3, input_4, output); | |||||
| return; | |||||
| } | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, | |||||
| float* output, cudaStream_t cuda_stream); | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, | |||||
| int* output, cudaStream_t cuda_stream); | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, | |||||
| half* output, cudaStream_t cuda_stream); | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||||
| const float* input_1, const float* input_2, const float* input_3, | |||||
| float* output, cudaStream_t cuda_stream); | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||||
| const int* input_1, const int* input_2, const int* input_3, | |||||
| int* output, cudaStream_t cuda_stream); | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||||
| const half* input_1, const half* input_2, const half* input_3, | |||||
| half* output, cudaStream_t cuda_stream); | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||||
| const float* input_1, const float* input_2, const float* input_3, const float* input_4, | |||||
| float* output, cudaStream_t cuda_stream); | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||||
| const int* input_1, const int* input_2, const int* input_3, const int* input_4, | |||||
| int* output, cudaStream_t cuda_stream); | |||||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||||
| const half* input_1, const half* input_2, const half* input_3, const half* input_4, | |||||
| half* output, cudaStream_t cuda_stream); | |||||
| @@ -19,7 +19,13 @@ | |||||
| #include "device/gpu/cuda_common.h" | #include "device/gpu/cuda_common.h" | ||||
| template <typename T> | template <typename T> | ||||
| void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, | |||||
| cudaStream_t cuda_stream); | |||||
| void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, | |||||
| cudaStream_t cuda_stream); | |||||
| template <typename T> | |||||
| void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||||
| const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream); | |||||
| template <typename T> | |||||
| void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||||
| const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ | ||||
| @@ -113,3 +113,62 @@ def test_axis21(): | |||||
| [2., 3., 3., 4., 5.]] | [2., 3., 3., 4., 5.]] | ||||
| assert (output.asnumpy() == expect).all() | assert (output.asnumpy() == expect).all() | ||||
| print(output) | print(output) | ||||
| class Concat3INet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Concat3INet, self).__init__() | |||||
| self.cat = P.Concat(axis=1) | |||||
| def construct(self, x1, x2, x3): | |||||
| return self.cat((x1, x2, x3)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_concat_3i(): | |||||
| cat = Concat3INet() | |||||
| x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32) | |||||
| x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32) | |||||
| x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32) | |||||
| output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1) | |||||
| x1_ms = Tensor(x1_np) | |||||
| x2_ms = Tensor(x2_np) | |||||
| x3_ms = Tensor(x3_np) | |||||
| output_ms = cat(x1_ms, x2_ms, x3_ms) | |||||
| error = np.ones(shape=output_np.shape) * 10e-6 | |||||
| diff = output_ms.asnumpy() - output_np | |||||
| assert np.all(diff < error) | |||||
| class Concat4INet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Concat4INet, self).__init__() | |||||
| self.cat = P.Concat(axis=1) | |||||
| def construct(self, x1, x2, x3, x4): | |||||
| return self.cat((x1, x2, x3, x4)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_concat_4i(): | |||||
| cat = Concat4INet() | |||||
| x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32) | |||||
| x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32) | |||||
| x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32) | |||||
| x4_np = np.random.randn(32, 5, 224, 224).astype(np.float32) | |||||
| output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1) | |||||
| x1_ms = Tensor(x1_np) | |||||
| x2_ms = Tensor(x2_np) | |||||
| x3_ms = Tensor(x3_np) | |||||
| x4_ms = Tensor(x4_np) | |||||
| output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms) | |||||
| error = np.ones(shape=output_np.shape) * 10e-6 | |||||
| diff = output_ms.asnumpy() - output_np | |||||
| assert np.all(diff < error) | |||||