From 57cb024e03d1546fbbdcf5ef8913cdc276720008 Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Sat, 20 Mar 2021 02:01:54 -0400 Subject: [PATCH] add float64 support for ops for end of version 1.1 fix typo --- .../gpu/arrays/argmaxwithvalue_gpu_kernel.cc | 6 +++++- .../gpu/arrays/broadcast_to_gpu_kernel.cc | 4 +++- .../gpu/arrays/tensor_scatter_update_gpu_kernel.cc | 8 ++++++++ .../gpu/cuda_impl/broadcast_grad_impl.cu | 9 ++++++++- .../kernel_compiler/gpu/cuda_impl/broadcast_impl.cu | 3 +++ .../gpu/cuda_impl/general_reduction_impl.cu | 4 +++- .../gpu/cuda_impl/tensor_scatter_update.cu | 5 +++++ .../kernel_compiler/gpu/math/broadcast_gpu_kernel.cc | 7 +++++++ .../gpu/math/broadcast_grad_gpu_kernel.cc | 10 +++++++++- 9 files changed, 51 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc index 5ead387ccc..4ddbd35e07 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,10 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + ArgmaxWithValueGpuKernel, double, int) MS_REG_GPU_KERNEL_TWO( ArgMaxWithValue, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc index 96e82bc5f3..0d5387186b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + BroadcastToGpuKernel, double) MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), BroadcastToGpuKernel, float) MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc index f7ef765ab5..6412c834a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc @@ -34,6 +34,14 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, .AddOutputAttr(kNumberTypeFloat32), TensorScatterUpdateGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + TensorScatterUpdateGpuFwdKernel, double, int) + MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, KernelAttr() .AddInputAttr(kNumberTypeInt8) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu index 150fb31e8e..b1e54d0772 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -113,6 +113,9 @@ void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, NoBroadcastGradKernel<<>>(nums, grad_x1, grad_x2, op, x1, x2, dy, dx1, dx2); } +template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, + const double *x1, const double *x2, const double *dy, double *dx1, double *dx2, + cudaStream_t stream); template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1, float *dx2, cudaStream_t stream); @@ -124,6 +127,10 @@ template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool & template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const int64_t *x1, const int64_t *x2, const int64_t *dy, int64_t *dx1, int64_t *dx2, cudaStream_t stream); +template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const double *x1, + const double *x2, const double *dy, double *dx1, double *dx2, cudaStream_t stream); template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu index 1675976004..c00bfed8ba 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -555,6 +555,9 @@ void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const siz output_addr); } +template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0, + const size_t &o1, const size_t &o2, const size_t &o3, const double *input_addr, + double *output_addr, cudaStream_t stream); template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0, const size_t &o1, const size_t &o2, const size_t &o3, const float *input_addr, float *output_addr, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu index d7b5c57261..534b374596 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -327,6 +327,8 @@ void CalGeneralReduction(bool small, const T *input, const size_t bound, const s return; } +template void CalGeneralReduction(bool small, const double *input, const size_t bound_, const size_t outerSize_, + const size_t innerSize_, int *index, double *output, cudaStream_t cuda_stream); template void CalGeneralReduction(bool small, const float *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_, int *index, float *output, cudaStream_t cuda_stream); template void CalGeneralReduction(bool small, const half *input, const size_t bound_, const size_t outerSize_, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu index 8470cf29c9..e322e579f9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu @@ -67,6 +67,11 @@ template void TensorScatterUpdate(float *input, int *indices, float const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, cudaStream_t stream); +template void TensorScatterUpdate(double *input, int *indices, double *update, double *output, + const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int *indices_stride, int *work_shape, + cudaStream_t stream); template void TensorScatterUpdate(char *input, int *indices, char *update, char *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc index 9b38d39db0..481becc694 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc @@ -23,6 +23,10 @@ MS_REG_GPU_KERNEL_ONE( Greater, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), BroadcastOpGpuKernel, double) +MS_REG_GPU_KERNEL_ONE( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + BroadcastOpGpuKernel, double) MS_REG_GPU_KERNEL_ONE( Less, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), BroadcastOpGpuKernel, double) @@ -46,6 +50,9 @@ MS_REG_GPU_KERNEL_ONE( RealDiv, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), BroadcastOpGpuKernel, double) +MS_REG_GPU_KERNEL_ONE( + Pow, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + BroadcastOpGpuKernel, double) // fp32 MS_REG_GPU_KERNEL_ONE( diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc index 4d9ef3b184..6500da7fbe 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,14 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_ONE(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + BroadcastOpGradGpuKernel, double) MS_REG_GPU_KERNEL_ONE(MinimumGrad, KernelAttr() .AddInputAttr(kNumberTypeFloat32)