diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc
index 0101f65001..74edfc2156 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc
@@ -27,5 +27,8 @@ MS_REG_GPU_KERNEL_ONE(Split,
 MS_REG_GPU_KERNEL_ONE(
   Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   SplitGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+  SplitGpuFwdKernel, uint32_t)
 } // namespace kernel
 } // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
index e892a3b47d..1a17989df2 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
@@ -48,3 +48,6 @@ template void SplitKernel(const size_t size, const int axis_step, const int all_
 template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
                           const int all_size_axis, const half* input, half** outputs,
                           cudaStream_t cuda_stream);
+template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
+                          const int all_size_axis, const uint32_t* input, uint32_t** outputs,
+                          cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h
index 70a57cded0..dc116a6826 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h
@@ -42,12 +42,18 @@ class SGDGpuKernel : public GpuKernel {
     T *accum = GetDeviceAddress<T>(inputs, 3);
     T *momentum = GetDeviceAddress<T>(inputs, 4);
     T *stat = GetDeviceAddress<T>(inputs, 5);
+    T *output_param = GetDeviceAddress<T>(outputs, 0);
 
     SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat,
         reinterpret_cast<cudaStream_t>(stream));
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(output_param, param, sizeof(T) * size_, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream)),
+                               "SGD cudaMemcpyAsync params to outputs failed");
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
     dampening_ = GetAttr(kernel_node, "dampening");
     weight_decay_ = GetAttr(kernel_node, "weight_decay");
     nesterov_ = GetAttr(kernel_node, "nesterov");