| @@ -19,15 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Concat, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ConcatV2GpuFwdKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Concat, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ConcatV2GpuFwdKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Concat, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ConcatV2GpuFwdKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Concat, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| ConcatV2GpuFwdKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -27,7 +27,7 @@ namespace kernel { | |||
| template <typename T> | |||
| class ConcatV2GpuFwdKernel : public GpuKernel { | |||
| public: | |||
| ConcatV2GpuFwdKernel() : axis_(0), input0_size_(0), input1_size_(0), output_size_(0), workspace_size_(0) {} | |||
| ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {} | |||
| ~ConcatV2GpuFwdKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| @@ -35,12 +35,32 @@ class ConcatV2GpuFwdKernel : public GpuKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| T *input_0 = GetDeviceAddress<T>(inputs, 0); | |||
| T *input_1 = GetDeviceAddress<T>(inputs, 1); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| if (inputs.size() == 2) { | |||
| T *input_0 = GetDeviceAddress<T>(inputs, 0); | |||
| T *input_1 = GetDeviceAddress<T>(inputs, 1); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| if (inputs.size() == 3) { | |||
| T *input_0 = GetDeviceAddress<T>(inputs, 0); | |||
| T *input_1 = GetDeviceAddress<T>(inputs, 1); | |||
| T *input_2 = GetDeviceAddress<T>(inputs, 2); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| CalConcatV2(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (inputs.size() == 4) { | |||
| T *input_0 = GetDeviceAddress<T>(inputs, 0); | |||
| T *input_1 = GetDeviceAddress<T>(inputs, 1); | |||
| T *input_2 = GetDeviceAddress<T>(inputs, 2); | |||
| T *input_3 = GetDeviceAddress<T>(inputs, 3); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| @@ -48,44 +68,44 @@ class ConcatV2GpuFwdKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| input0_size_ = sizeof(T); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input0_size_ *= input_shape[i]; | |||
| } | |||
| auto input_shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| input1_size_ = sizeof(T); | |||
| for (size_t i = 0; i < input_shape1.size(); i++) { | |||
| input1_size_ *= input_shape1[i]; | |||
| } | |||
| output_size_ = input0_size_ + input1_size_; | |||
| axis_ = GetAttr<int>(kernel_node, "axis"); | |||
| if (axis_ < 0) { | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| axis_ += SizeToInt(input_shape.size()); | |||
| } | |||
| w_[0] = 1; | |||
| w_[1] = 1; | |||
| for (size_t i = IntToSize(axis_); i < input_shape.size(); i++) { | |||
| w_[0] *= SizeToInt(input_shape[i]); | |||
| w_[1] *= SizeToInt(input_shape1[i]); | |||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t i = 0; i < input_num; i++) { | |||
| auto input_size = sizeof(T); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); | |||
| for (size_t j = 0; j < input_shape.size(); j++) { | |||
| input_size *= SizeToInt(input_shape[j]); | |||
| if (j >= IntToSize(axis_)) { | |||
| w_[i] *= SizeToInt(input_shape[j]); | |||
| } | |||
| input_size_list_.push_back(input_size); | |||
| } | |||
| } | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| output_size_ = sizeof(T); | |||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||
| output_size_ *= output_shape[i]; | |||
| } | |||
| output_size_list_.push_back(output_size_); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input0_size_); | |||
| input_size_list_.push_back(input1_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| } | |||
| void InitSizeLists() override {} | |||
| private: | |||
| bool CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs 2 inputs."; | |||
| if (input_num < 2 || input_num > 4) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| @@ -95,16 +115,12 @@ class ConcatV2GpuFwdKernel : public GpuKernel { | |||
| } | |||
| return true; | |||
| } | |||
| int w_[2] = {1}; | |||
| int w_[4] = {1, 1, 1, 1}; | |||
| int axis_; | |||
| size_t output_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| size_t input0_size_; | |||
| size_t input1_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -19,7 +19,7 @@ | |||
| #include <cuda_runtime.h> | |||
| #include "kernel/gpu/cuda_impl/concatv2_impl.cuh" | |||
| template <typename T> | |||
| __global__ void ConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { | |||
| __global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| int n = pos / (w1 + w2); | |||
| int m = pos % (w1 + w2); | |||
| @@ -29,16 +29,80 @@ __global__ void ConcatV2(const size_t size, const int w1, const int w2, const T* | |||
| } | |||
| template <typename T> | |||
| void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, | |||
| __global__ void Concat(const size_t size, const int w1, const int w2, const int w3, | |||
| const T* input_1, const T* input_2, const T* input_3, T* output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| int n = pos / (w1 + w2 + w3); | |||
| int m = pos % (w1 + w2 + w3); | |||
| output[pos] = m < w1 ? input_1[n * w1 + m] : | |||
| m < w1 + w2 ? input_2[n * w2 + m - w1] : | |||
| input_3[n * w3 + m - w1 - w2]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||
| const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| int n = pos / (w1 + w2 + w3 + w4); | |||
| int m = pos % (w1 + w2 + w3 + w4); | |||
| output[pos] = m < w1 ? input_1[n * w1 + m] : | |||
| m < w1 + w2 ? input_2[n * w2 + m - w1]: | |||
| m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]: | |||
| input_4[n * w4 + m - w1 - w2 - w3]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, | |||
| cudaStream_t cuda_stream) { | |||
| ConcatV2<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output); | |||
| Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||
| const T* input_1, const T* input_2, const T* input_3, T* output, | |||
| cudaStream_t cuda_stream) { | |||
| Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output); | |||
| return; | |||
| } | |||
| template void CalConcatV2(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, | |||
| float* output, cudaStream_t cuda_stream); | |||
| template void CalConcatV2(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, | |||
| int* output, cudaStream_t cuda_stream); | |||
| template void CalConcatV2(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, | |||
| half* output, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||
| const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, | |||
| cudaStream_t cuda_stream) { | |||
| Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1, | |||
| input_2, input_3, input_4, output); | |||
| return; | |||
| } | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, | |||
| float* output, cudaStream_t cuda_stream); | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, | |||
| int* output, cudaStream_t cuda_stream); | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, | |||
| half* output, cudaStream_t cuda_stream); | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||
| const float* input_1, const float* input_2, const float* input_3, | |||
| float* output, cudaStream_t cuda_stream); | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||
| const int* input_1, const int* input_2, const int* input_3, | |||
| int* output, cudaStream_t cuda_stream); | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||
| const half* input_1, const half* input_2, const half* input_3, | |||
| half* output, cudaStream_t cuda_stream); | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||
| const float* input_1, const float* input_2, const float* input_3, const float* input_4, | |||
| float* output, cudaStream_t cuda_stream); | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||
| const int* input_1, const int* input_2, const int* input_3, const int* input_4, | |||
| int* output, cudaStream_t cuda_stream); | |||
| template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||
| const half* input_1, const half* input_2, const half* input_3, const half* input_4, | |||
| half* output, cudaStream_t cuda_stream); | |||
| @@ -19,7 +19,13 @@ | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, | |||
| cudaStream_t cuda_stream); | |||
| void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, | |||
| const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, | |||
| const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ | |||
| @@ -113,3 +113,62 @@ def test_axis21(): | |||
| [2., 3., 3., 4., 5.]] | |||
| assert (output.asnumpy() == expect).all() | |||
| print(output) | |||
| class Concat3INet(nn.Cell): | |||
| def __init__(self): | |||
| super(Concat3INet, self).__init__() | |||
| self.cat = P.Concat(axis=1) | |||
| def construct(self, x1, x2, x3): | |||
| return self.cat((x1, x2, x3)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_concat_3i(): | |||
| cat = Concat3INet() | |||
| x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32) | |||
| x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32) | |||
| x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32) | |||
| output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1) | |||
| x1_ms = Tensor(x1_np) | |||
| x2_ms = Tensor(x2_np) | |||
| x3_ms = Tensor(x3_np) | |||
| output_ms = cat(x1_ms, x2_ms, x3_ms) | |||
| error = np.ones(shape=output_np.shape) * 10e-6 | |||
| diff = output_ms.asnumpy() - output_np | |||
| assert np.all(diff < error) | |||
| class Concat4INet(nn.Cell): | |||
| def __init__(self): | |||
| super(Concat4INet, self).__init__() | |||
| self.cat = P.Concat(axis=1) | |||
| def construct(self, x1, x2, x3, x4): | |||
| return self.cat((x1, x2, x3, x4)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_concat_4i(): | |||
| cat = Concat4INet() | |||
| x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32) | |||
| x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32) | |||
| x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32) | |||
| x4_np = np.random.randn(32, 5, 224, 224).astype(np.float32) | |||
| output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1) | |||
| x1_ms = Tensor(x1_np) | |||
| x2_ms = Tensor(x2_np) | |||
| x3_ms = Tensor(x3_np) | |||
| x4_ms = Tensor(x4_np) | |||
| output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms) | |||
| error = np.ones(shape=output_np.shape) * 10e-6 | |||
| diff = output_ms.asnumpy() - output_np | |||
| assert np.all(diff < error) | |||