Browse Source

!25341 Add concat op dtypes.

Merge pull request !25341 from liangchenghui/add_concat_dtype
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
475777af3f
2 changed files with 42 additions and 5 deletions
  1. +21
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc
  2. +21
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu

+ 21
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc View File

@@ -24,18 +24,36 @@ MS_REG_GPU_KERNEL_ONE(
// GPU kernel registrations for the Concat op, one entry per supported element
// dtype. AddAllSameAttr(true) requires every input tensor to share the same
// dtype, which must also match the output dtype; each entry binds that dtype
// to the matching ConcatV2GpuFwdKernel template instantiation.
//
// NOTE(review): the original list registered kNumberTypeInt32 twice (identical
// KernelAttr); the duplicate is removed here — registering the same attr twice
// is at best redundant and at worst an overwrite/conflict at registry load.

// Floating-point types.
MS_REG_GPU_KERNEL_ONE(
  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  ConcatV2GpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  ConcatV2GpuFwdKernel, half)

// Signed integer types.
MS_REG_GPU_KERNEL_ONE(Concat,
                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      ConcatV2GpuFwdKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(Concat,
                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      ConcatV2GpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(Concat,
                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      ConcatV2GpuFwdKernel, short)  // NOLINT(runtime/int): registry requires the builtin type name
MS_REG_GPU_KERNEL_ONE(Concat,
                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
                      ConcatV2GpuFwdKernel, char)

// Unsigned integer types.
MS_REG_GPU_KERNEL_ONE(
  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
  ConcatV2GpuFwdKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(
  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
  ConcatV2GpuFwdKernel, uint)
MS_REG_GPU_KERNEL_ONE(
  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
  ConcatV2GpuFwdKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(Concat,
                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                      ConcatV2GpuFwdKernel, uchar)

// Boolean.
MS_REG_GPU_KERNEL_ONE(Concat,
                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                      ConcatV2GpuFwdKernel, bool)


+ 21
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu View File

@@ -57,17 +57,36 @@ template void ConcatKernel(const size_t size, const int input_num, const int all
// Explicit instantiations of ConcatKernel for every dtype registered in
// concatv2_gpu_kernel.cc. This list must stay in one-to-one correspondence
// with the MS_REG_GPU_KERNEL_ONE entries there, since the template body lives
// in this .cu translation unit and callers link against these symbols.
//
// NOTE(review): the original list explicitly instantiated ConcatKernel<int>
// twice; a second explicit instantiation of the same specialization in one
// translation unit is ill-formed C++ ([temp.spec]), so the duplicate is
// removed.

// Floating-point types.
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, float **inputs, float *output,
                           cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, half **inputs, half *output,
                           cudaStream_t cuda_stream);

// Signed integer types.
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, int64_t **inputs, int64_t *output,
                           cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, int **inputs, int *output,
                           cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, short **inputs, short *output,  // NOLINT
                           cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, char **inputs, char *output,
                           cudaStream_t cuda_stream);

// Unsigned integer types.
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, uint64_t **inputs, uint64_t *output,
                           cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, uint32_t **inputs, uint32_t *output,
                           cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, uint16_t **inputs, uint16_t *output,
                           cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, unsigned char **inputs, unsigned char *output,
                           cudaStream_t cuda_stream);

// Boolean.
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
                           const int all_size_axis, int *len_axis, bool **inputs, bool *output,
                           cudaStream_t cuda_stream);

Loading…
Cancel
Save