@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,9 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  ConcatV2GpuFwdKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   ConcatV2GpuFwdKernel, float)
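With the registration above, the GPU `Concat` kernel should also be selected for float64 inputs. A minimal usage sketch (assumes a MindSpore GPU build with this change applied; values and shapes are illustrative only):

```python
# Sketch only: exercises the new float64 path through the Concat primitive.
import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="GPU")

x1 = Tensor(np.array([[0., 1.], [2., 3.]], dtype=np.float64))
x2 = Tensor(np.array([[4., 5.], [6., 7.]], dtype=np.float64))
out = P.Concat(axis=1)((x1, x2))  # inputs are passed as a single tuple
print(out.shape, out.dtype)       # expected: (2, 4) Float64
```

The float64 test cases added at the end of this change cover the same path more thoroughly.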
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
 #include <vector>
 #include <memory>
@@ -133,4 +133,4 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,9 +19,8 @@
 #include <cuda_runtime.h>
 #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"

 template <typename T>
-__global__ void Concat(const size_t size, const int input_num,
-                       const int all_size_before_axis, const int all_size_axis,
-                       int* len_axis, T** inputs, T* output) {
+__global__ void Concat(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
+                       int *len_axis, T **inputs, T *output) {
   for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
     int num = pos % all_size_before_axis / all_size_axis;
     int block = -1;
@@ -37,45 +36,38 @@ __global__ void Concat(const size_t size, const int input_num,
     }
     block_len = len_axis[block];
     axis_inc -= len_axis[block];
-    int block_pos = pos / all_size_before_axis * block_len * all_size_axis +
-                    (num - axis_inc) * all_size_axis + pos % all_size_axis;;
+    int block_pos =
+      pos / all_size_before_axis * block_len * all_size_axis + (num - axis_inc) * all_size_axis + pos % all_size_axis;
     output[pos] = inputs[block][block_pos];
   }
   return;
 }

 template <typename T>
-void ConcatKernel(const size_t size, const int input_num,
-                  const int all_size_before_axis, const int all_size_axis,
-                  int* len_axis, T** inputs, T* output,
-                  cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num,
-                                                            all_size_before_axis, all_size_axis,
+void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
+                  int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream) {
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, all_size_before_axis, all_size_axis,
                                                             len_axis, inputs, output);
   return;
 }

-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, float** inputs, float* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, double **inputs, double *output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, int** inputs, int* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, float **inputs, float *output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, half** inputs, half* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, int **inputs, int *output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, half **inputs, half *output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, short** inputs, short* output,  // NOLINT
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, short **inputs, short *output,  // NOLINT
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, unsigned char** inputs, unsigned char* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, unsigned char **inputs, unsigned char *output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, bool** inputs, bool* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, bool **inputs, bool *output,
                            cudaStream_t cuda_stream);
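The reformatted kernel keeps the original index arithmetic: each flat output position is split into an outer index, an index along the concatenation axis, and an inner index, and the axis index is mapped back to the input that owns it. Below is a pure-NumPy sketch of that mapping, with names mirroring the kernel parameters; it is illustrative only, not part of the change, and assumes a non-negative `axis`.

```python
import numpy as np

def concat_reference(inputs, axis):
    """Reference of the flat-index mapping used by the Concat CUDA kernel above."""
    out_shape = list(inputs[0].shape)
    out_shape[axis] = sum(x.shape[axis] for x in inputs)
    len_axis = [x.shape[axis] for x in inputs]               # size of each input along the axis
    all_size_axis = int(np.prod(out_shape[axis + 1:]))       # product of dims after the axis
    all_size_before_axis = out_shape[axis] * all_size_axis   # axis size times the inner size

    flat_inputs = [x.reshape(-1) for x in inputs]
    out = np.empty(int(np.prod(out_shape)), dtype=inputs[0].dtype)
    for pos in range(out.size):
        outer = pos // all_size_before_axis
        num = pos % all_size_before_axis // all_size_axis    # index along the concat axis
        inner = pos % all_size_axis
        block, axis_inc = 0, 0                               # find the input block that owns `num`
        while num >= axis_inc + len_axis[block]:
            axis_inc += len_axis[block]
            block += 1
        block_pos = (outer * len_axis[block] * all_size_axis
                     + (num - axis_inc) * all_size_axis + inner)
        out[pos] = flat_inputs[block][block_pos]
    return out.reshape(out_shape)

# Sanity check against NumPy's own concatenate (float64, axis=1):
a = np.arange(6, dtype=np.float64).reshape(2, 3)
b = np.arange(6, 12, dtype=np.float64).reshape(2, 3)
assert np.array_equal(concat_reference([a, b], axis=1), np.concatenate([a, b], axis=1))
```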
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,13 +14,11 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void ConcatKernel(const size_t size, const int input_num,
-                  const int all_size_before_axis, const int all_size_axis,
-                  int* len_axis, T** inputs, T* output,
-                  cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
+void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
+                  int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
@@ -2162,7 +2162,7 @@ class Concat(PrimitiveWithInfer):

     @prim_attr_register
     def __init__(self, axis=0):
-        """Initialize Tile"""
+        """Initialize Concat"""
         validator.check_value_type("axis", axis, [int], self.name)

     def __infer__(self, input_x):
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -49,9 +49,14 @@ def axis32(nptype):
                         [1., 2., 3.]],
                        [[2., 4., 5.],
                         [3., 6., 7.]]]).astype(nptype)
-    print(output)
     assert (output.asnumpy() == expect).all()

+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_axis32_float64():
+    axis32(np.float64)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -106,8 +111,12 @@ def axis43(nptype):
                          [[12., 13., 18., 19., 20.],
                           [14., 15., 21., 22., 23.]]]]).astype(nptype)
     assert (output.asnumpy() == expect).all()
-    print(output)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_axis43_float64():
+    axis43(np.float64)

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -155,7 +164,12 @@ def axis21(nptype):
     expect = np.array([[0., 1., 0., 1., 2.],
                        [2., 3., 3., 4., 5.]]).astype(nptype)
     assert (output.asnumpy() == expect).all()
-    print(output)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_axis21_float64():
+    axis21(np.float64)

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -208,6 +222,12 @@ def concat_3i(nptype):
     diff = output_ms.asnumpy() - output_np
     assert np.all(diff < error)

+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_concat_3i_float64():
+    concat_3i(np.float64)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -273,6 +293,12 @@ def concat_4i(nptype):
     diff = output_ms.asnumpy() - output_np
     assert np.all(diff < error)

+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_concat_4i_float64():
+    concat_4i(np.float64)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard