|
|
@@ -18,69 +18,54 @@ |
|
|
|
|
|
|
|
|
namespace mindspore { |
|
|
namespace mindspore { |
|
|
namespace kernel { |
|
|
namespace kernel { |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, |
|
|
Gather, |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), |
|
|
GatherV2GpuFwdKernel, double, int) |
|
|
GatherV2GpuFwdKernel, double, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, |
|
|
Gather, |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), |
|
|
GatherV2GpuFwdKernel, double, int64_t) |
|
|
GatherV2GpuFwdKernel, double, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, |
|
|
Gather, |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), |
|
|
GatherV2GpuFwdKernel, float, int) |
|
|
GatherV2GpuFwdKernel, float, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, |
|
|
Gather, |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), |
|
|
GatherV2GpuFwdKernel, float, int64_t) |
|
|
GatherV2GpuFwdKernel, float, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, |
|
|
Gather, |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), |
|
|
GatherV2GpuFwdKernel, half, int) |
|
|
GatherV2GpuFwdKernel, half, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, |
|
|
Gather, |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), |
|
|
GatherV2GpuFwdKernel, half, int64_t) |
|
|
GatherV2GpuFwdKernel, half, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), |
|
|
GatherV2GpuFwdKernel, int, int) |
|
|
GatherV2GpuFwdKernel, int, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), |
|
|
GatherV2GpuFwdKernel, int, int64_t) |
|
|
GatherV2GpuFwdKernel, int, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), |
|
|
GatherV2GpuFwdKernel, int16_t, int) |
|
|
GatherV2GpuFwdKernel, int16_t, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), |
|
|
GatherV2GpuFwdKernel, int16_t, int64_t) |
|
|
GatherV2GpuFwdKernel, int16_t, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), |
|
|
GatherV2GpuFwdKernel, int8_t, int) |
|
|
GatherV2GpuFwdKernel, int8_t, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), |
|
|
GatherV2GpuFwdKernel, int8_t, int64_t) |
|
|
GatherV2GpuFwdKernel, int8_t, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), |
|
|
GatherV2GpuFwdKernel, uint8_t, int) |
|
|
GatherV2GpuFwdKernel, uint8_t, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), |
|
|
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), |
|
|
GatherV2GpuFwdKernel, uint8_t, int64_t) |
|
|
GatherV2GpuFwdKernel, uint8_t, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
KernelAttr() |
|
|
KernelAttr() |
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
@@ -88,7 +73,6 @@ MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddOutputAttr(kNumberTypeFloat32), |
|
|
.AddOutputAttr(kNumberTypeFloat32), |
|
|
GatherV2GpuFwdKernel, float, int) |
|
|
GatherV2GpuFwdKernel, float, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
KernelAttr() |
|
|
KernelAttr() |
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
@@ -96,7 +80,6 @@ MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddOutputAttr(kNumberTypeFloat32), |
|
|
.AddOutputAttr(kNumberTypeFloat32), |
|
|
GatherV2GpuFwdKernel, float, int64_t) |
|
|
GatherV2GpuFwdKernel, float, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
KernelAttr() |
|
|
KernelAttr() |
|
|
.AddInputAttr(kNumberTypeFloat16) |
|
|
.AddInputAttr(kNumberTypeFloat16) |
|
|
@@ -104,7 +87,6 @@ MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddOutputAttr(kNumberTypeFloat16), |
|
|
.AddOutputAttr(kNumberTypeFloat16), |
|
|
GatherV2GpuFwdKernel, half, int) |
|
|
GatherV2GpuFwdKernel, half, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
KernelAttr() |
|
|
KernelAttr() |
|
|
.AddInputAttr(kNumberTypeFloat16) |
|
|
.AddInputAttr(kNumberTypeFloat16) |
|
|
@@ -112,17 +94,14 @@ MS_REG_GPU_KERNEL_TWO(Gather, |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddOutputAttr(kNumberTypeFloat16), |
|
|
.AddOutputAttr(kNumberTypeFloat16), |
|
|
GatherV2GpuFwdKernel, half, int64_t) |
|
|
GatherV2GpuFwdKernel, half, int64_t) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
SparseGatherV2, |
|
|
SparseGatherV2, |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), |
|
|
GatherV2GpuFwdKernel, float, int) |
|
|
GatherV2GpuFwdKernel, float, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
MS_REG_GPU_KERNEL_TWO( |
|
|
SparseGatherV2, |
|
|
SparseGatherV2, |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), |
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), |
|
|
GatherV2GpuFwdKernel, half, int) |
|
|
GatherV2GpuFwdKernel, half, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO(SparseGatherV2, |
|
|
MS_REG_GPU_KERNEL_TWO(SparseGatherV2, |
|
|
KernelAttr() |
|
|
KernelAttr() |
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
@@ -130,7 +109,6 @@ MS_REG_GPU_KERNEL_TWO(SparseGatherV2, |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
.AddOutputAttr(kNumberTypeFloat32), |
|
|
.AddOutputAttr(kNumberTypeFloat32), |
|
|
GatherV2GpuFwdKernel, float, int) |
|
|
GatherV2GpuFwdKernel, float, int) |
|
|
|
|
|
|
|
|
MS_REG_GPU_KERNEL_TWO(SparseGatherV2, |
|
|
MS_REG_GPU_KERNEL_TWO(SparseGatherV2, |
|
|
KernelAttr() |
|
|
KernelAttr() |
|
|
.AddInputAttr(kNumberTypeFloat16) |
|
|
.AddInputAttr(kNumberTypeFloat16) |
|
|
|