From: @linqingke Reviewed-by: @liangchenghui,@oacjiewen Signed-off-by: @liangchenghuitags/v1.1.0
| @@ -27,5 +27,8 @@ MS_REG_GPU_KERNEL_ONE(Split, | |||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | ||||
| SplitGpuFwdKernel, half) | SplitGpuFwdKernel, half) | ||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), | |||||
| SplitGpuFwdKernel, uint32_t) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -48,3 +48,6 @@ template void SplitKernel(const size_t size, const int axis_step, const int all_ | |||||
| template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis, | template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis, | ||||
| const int all_size_axis, const half* input, half** outputs, | const int all_size_axis, const half* input, half** outputs, | ||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis, | |||||
| const int all_size_axis, const uint32_t* input, uint32_t** outputs, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -42,12 +42,18 @@ class SGDGpuKernel : public GpuKernel { | |||||
| T *accum = GetDeviceAddress<T>(inputs, 3); | T *accum = GetDeviceAddress<T>(inputs, 3); | ||||
| T *momentum = GetDeviceAddress<T>(inputs, 4); | T *momentum = GetDeviceAddress<T>(inputs, 4); | ||||
| T *stat = GetDeviceAddress<T>(inputs, 5); | T *stat = GetDeviceAddress<T>(inputs, 5); | ||||
| T *output_param = GetDeviceAddress<T>(outputs, 0); | |||||
| SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat, | SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat, | ||||
| reinterpret_cast<cudaStream_t>(stream)); | reinterpret_cast<cudaStream_t>(stream)); | ||||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||||
| cudaMemcpyAsync(output_param, param, sizeof(T) * size_, cudaMemcpyDeviceToDevice, | |||||
| reinterpret_cast<cudaStream_t>(stream)), | |||||
| "SGD cudaMemcpyAsync params to outputs failed"); | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| kernel_node_ = kernel_node; | |||||
| dampening_ = GetAttr<float>(kernel_node, "dampening"); | dampening_ = GetAttr<float>(kernel_node, "dampening"); | ||||
| weight_decay_ = GetAttr<float>(kernel_node, "weight_decay"); | weight_decay_ = GetAttr<float>(kernel_node, "weight_decay"); | ||||
| nesterov_ = GetAttr<bool>(kernel_node, "nesterov"); | nesterov_ = GetAttr<bool>(kernel_node, "nesterov"); | ||||