diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu index 1e94fd57d5..21c7e4b786 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu @@ -99,7 +99,7 @@ template __global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate, const S *gradient, const T *momentum) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) { - accumulation[i] = momentum[0] * accumulation[i] + static_cast(gradient[i]); + accumulation[i] = momentum[0] * accumulation[i] + static_cast(gradient[i]) * scale[0]; variable[i] -= learning_rate[0] * accumulation[i]; } } @@ -113,6 +113,56 @@ void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accu element_num, scale, variable, accumulation, learning_rate, gradient, momentum); } +// CombineFusedScaleMomentum +template +__global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t *element_num, + T **scale, T **variable, T **accumulation, + T **learning_rate, S **gradient, T **momentum) { + for (size_t idx = 0; idx < num; idx++) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) { + accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + static_cast(gradient[idx][i]) * scale[idx][0]; + variable[idx][i] -= learning_rate[idx][0] * accumulation[idx][i]; + } + } +} + +template +void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale, + T **variable, T **accumulation, T **learning_rate, S **gradient, + T **momentum, cudaStream_t cuda_stream) { + size_t thread_per_block = 256; + size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block; + CombineFusedMomentumScaleMomentum<<>>( + num, elements, scale, variable, accumulation, learning_rate, gradient, momentum); +} +// end CombineFusedScaleMomentum + +// CombineFusedWeightDecayScaleMomentum +template +__global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, const size_t *element_num, + T **weight_decay, T **scale, T **variable, + T **accumulation, T **learning_rate, S **gradient, + T **momentum) { + for (size_t idx = 0; idx < num; idx++) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) { + T grad = (variable[idx][i] * weight_decay[idx][0] + static_cast(gradient[idx][i])) * scale[idx][0]; + accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + grad; + variable[idx][i] -= learning_rate[idx][0] * accumulation[idx][i]; + } + } +} + +template +void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element_num, + T **weight_decay, T **scale, T **variable, T **accumulation, + T **learning_rate, S **gradient, T **momentum, + cudaStream_t cuda_stream) { + size_t thread_per_block = 256; + size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block; + CombineFusedMomentumWeightDecayScaleMomentum<<>>( + num, element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum); +} +// end CombineFusedWeightDecayScaleMomentum template void MomentumUpdateVariable(const size_t size, float *variable, float *accumulation, const float *learning_rate, const float *gradient, const float *momentum, bool use_nesterov, @@ -142,3 +192,17 @@ template void FusedScaleMomentum(const size_t element_num, float *scale, float * template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation, const float *learning_rate, const half *gradient, const float *momentum, cudaStream_t cuda_stream); +template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements, + float **weight_decay, float **scale, float **variable, + float **accumulation, float **learning_rate, float **gradient, + float **momentum, cudaStream_t cuda_stream); +template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements, + float **weight_decay, float **scale, float **variable, + float **accumulation, float **learning_rate, half **gradient, + float **momentum, cudaStream_t cuda_stream); +template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale, + float **variable, float **accumulation, float **learning_rate, + float **gradient, float **momentum, cudaStream_t cuda_stream); +template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale, + float **variable, float **accumulation, float **learning_rate, + half **gradient, float **momentum, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh index 7ce15c97ed..7c9ac40bbd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh @@ -28,5 +28,12 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T template void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate, const S *gradient, const T *momentum, cudaStream_t cuda_stream); - +template +void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element, T **weight_decay, + T **scale, T **variable, T **accumulation, T **learning_rate, S **gradient, + T **momentum, cudaStream_t cuda_stream); +template +void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *element, T **scale, T **variable, + T **accumulation, T **learning_rate, S **gradient, T **momentum, + cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc index 8cd5a3c4ee..a2750f3243 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc @@ -40,22 +40,12 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel std::vector> *iter_second, size_t attr_index) { if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { - if (iter_second->at(attr_index).first.GetAllSame()) { - auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first; - for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddInputAttr(dtype); - } - } else { + if (!iter_second->at(attr_index).first.GetAllSame()) { MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; } } if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { - if (iter_second->at(attr_index).first.GetAllSame()) { - auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first; - for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddOutputAttr(dtype); - } - } else { + if (!iter_second->at(attr_index).first.GetAllSame()) { MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; } } @@ -99,6 +89,7 @@ std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string & for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index); bool flag = true; + auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize(); // data type matching check of all input parameters of kernel for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) { if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { @@ -110,7 +101,7 @@ std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string & << ", but the current device's computing capacity is " << marjor_sm; } if (kernel_info->GetInputDeviceType(input_index) != - (iter->second)[attr_index].first.GetInputAttr(input_index).first) { + (iter->second)[attr_index].first.GetInputAttr(input_index % attr_size).first) { flag = false; break; } @@ -118,10 +109,11 @@ std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string & if (!flag) { continue; } + attr_size = (&(iter->second))->at(attr_index).first.GetOutputSize(); // data type matching check of all output parameters of kernel for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) { if (kernel_info->GetOutputDeviceType(output_index) != - (iter->second)[attr_index].first.GetOutputAttr(output_index).first) { + (iter->second)[attr_index].first.GetOutputAttr(output_index % attr_size).first) { flag = false; break; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.cc new file mode 100644 index 0000000000..1e36fce7b0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(CombineMomentum, + KernelAttr() + .AddAllSameAttr(true) + .AddInputAttr(kNumberTypeFloat32) // scale + .AddInputAttr(kNumberTypeFloat32) // variable + .AddInputAttr(kNumberTypeFloat32) // accumulation + .AddInputAttr(kNumberTypeFloat32) // learning_rate + .AddInputAttr(kNumberTypeFloat32) // gradient + .AddInputAttr(kNumberTypeFloat32) // momentum + .AddOutputAttr(kNumberTypeFloat32), + CombineMomentumGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO(CombineMomentum, + KernelAttr() + .AddAllSameAttr(true) + .AddInputAttr(kNumberTypeFloat32) // scale + .AddInputAttr(kNumberTypeFloat32) // variable + .AddInputAttr(kNumberTypeFloat32) // accumulation + .AddInputAttr(kNumberTypeFloat32) // variable + .AddInputAttr(kNumberTypeFloat32) // accumulation + .AddInputAttr(kNumberTypeFloat32) // learning_rate + .AddInputAttr(kNumberTypeFloat16) // gradient + .AddInputAttr(kNumberTypeFloat32) // momentum + .AddOutputAttr(kNumberTypeFloat32), + CombineMomentumGpuKernel, float, half) +MS_REG_GPU_KERNEL_TWO(CombineMomentumWeight, + KernelAttr() + .AddAllSameAttr(true) + .AddInputAttr(kNumberTypeFloat32) // weight decay + .AddInputAttr(kNumberTypeFloat32) // scale + .AddInputAttr(kNumberTypeFloat32) // variable + .AddInputAttr(kNumberTypeFloat32) // accumulation + .AddInputAttr(kNumberTypeFloat32) // learning_rate + .AddInputAttr(kNumberTypeFloat32) // gradient + .AddInputAttr(kNumberTypeFloat32) // momentum + .AddOutputAttr(kNumberTypeFloat32), + CombineMomentumGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO(CombineMomentumWeight, + KernelAttr() + .AddAllSameAttr(true) + .AddInputAttr(kNumberTypeFloat32) // variable + .AddInputAttr(kNumberTypeFloat32) // accumulation + .AddInputAttr(kNumberTypeFloat32) // variable + .AddInputAttr(kNumberTypeFloat32) // accumulation + .AddInputAttr(kNumberTypeFloat32) // learning_rate + .AddInputAttr(kNumberTypeFloat16) // gradient + .AddInputAttr(kNumberTypeFloat32) // momentum + .AddOutputAttr(kNumberTypeFloat32), + CombineMomentumGpuKernel, float, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h new file mode 100644 index 0000000000..81174dd81f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h @@ -0,0 +1,207 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" +namespace mindspore { +namespace kernel { +template +class CombineMomentumGpuKernel : public GpuKernel { + public: + CombineMomentumGpuKernel() : element_num_(1), num_(0), max_(0), input_num_(6) {} + ~CombineMomentumGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &workspace, void *stream_ptr) override { + const cudaStream_t stream = reinterpret_cast(stream_ptr); + auto weight_decay = std::make_unique(input_num_ * num_); + auto scale = std::make_unique(input_num_ * num_); + auto variable = std::make_unique(input_num_ * num_); + auto accumulation = std::make_unique(input_num_ * num_); + auto learning_rate = std::make_unique(input_num_ * num_); + auto gradient = std::make_unique(input_num_ * num_); + auto momentum = std::make_unique(input_num_ * num_); + if (input_num_ == 6) { + LaunchCombineMom(inputs, workspace, stream, scale, variable, accumulation, learning_rate, gradient, momentum); + } else { + LaunchCombineMomWeightDecay(inputs, workspace, stream, weight_decay, scale, variable, accumulation, learning_rate, + gradient, momentum); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + num_ = GetAttr(kernel_node, "n"); + elements_ = std::make_unique(num_); + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == "CombineMomentum") { + input_num_ = 6; + } else { + input_num_ = 7; + workspace_size_list_.push_back(sizeof(T *) * num_); + } + + for (size_t i = 0; i < num_; i++) { + element_num_ = 1; + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i * input_num_ + input_num_ - 4); + for (size_t j = 0; j < variable_shape.size(); j++) { + element_num_ *= variable_shape[j]; + } + if (max_ < element_num_) { + max_ = element_num_; + } + elements_[i] = element_num_; + InitSizeLists(); + } + workspace_size_list_.push_back(sizeof(T *) * num_); + workspace_size_list_.push_back(sizeof(T *) * num_); + workspace_size_list_.push_back(sizeof(T *) * num_); + workspace_size_list_.push_back(sizeof(T *) * num_); + workspace_size_list_.push_back(sizeof(S *) * num_); + workspace_size_list_.push_back(sizeof(T *) * num_); + workspace_size_list_.push_back(sizeof(size_t) * num_); + return true; + } + + protected: + void InitSizeLists() override { + if (input_num_ == 7) { + input_size_list_.push_back(sizeof(T)); + } + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(element_num_ * sizeof(T)); + input_size_list_.push_back(element_num_ * sizeof(T)); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(element_num_ * sizeof(S)); + input_size_list_.push_back(sizeof(T)); + output_size_list_.push_back(element_num_ * sizeof(T)); + } + + private: + void LaunchCombineMom(const std::vector &inputs, const std::vector &workspace, + const cudaStream_t &stream, const std::unique_ptr &scale, + const std::unique_ptr &variable, const std::unique_ptr &accumulation, + const std::unique_ptr &learning_rate, const std::unique_ptr &gradient, + const std::unique_ptr &momentum) { + for (size_t i = 0; i < num_; i++) { + scale[i] = GetDeviceAddress(inputs, i * input_num_); + variable[i] = GetDeviceAddress(inputs, i * input_num_ + 1); + accumulation[i] = GetDeviceAddress(inputs, i * input_num_ + 2); + learning_rate[i] = GetDeviceAddress(inputs, i * input_num_ + 3); + gradient[i] = GetDeviceAddress(inputs, i * input_num_ + 4); + momentum[i] = GetDeviceAddress(inputs, i * input_num_ + 5); + } + T **scale_dev = GetDeviceAddress(workspace, 0); + T **variable_dev = GetDeviceAddress(workspace, 1); + T **accumulation_dev = GetDeviceAddress(workspace, 2); + T **learning_rate_dev = GetDeviceAddress(workspace, 3); + S **gradient_dev = GetDeviceAddress(workspace, 4); + T **momentum_dev = GetDeviceAddress(workspace, 5); + size_t *elements_dev = GetDeviceAddress(workspace, 6); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CombineFusedScaleMomentum(max_, num_, elements_dev, scale_dev, variable_dev, accumulation_dev, learning_rate_dev, + gradient_dev, momentum_dev, stream); + } + void LaunchCombineMomWeightDecay(const std::vector &inputs, const std::vector &workspace, + const cudaStream_t &stream, const std::unique_ptr &weight_decay, + const std::unique_ptr &scale, const std::unique_ptr &variable, + const std::unique_ptr &accumulation, + const std::unique_ptr &learning_rate, const std::unique_ptr &gradient, + const std::unique_ptr &momentum) { + for (size_t i = 0; i < num_; i++) { + weight_decay[i] = GetDeviceAddress(inputs, i * input_num_); + scale[i] = GetDeviceAddress(inputs, i * input_num_ + 1); + variable[i] = GetDeviceAddress(inputs, i * input_num_ + 2); + accumulation[i] = GetDeviceAddress(inputs, i * input_num_ + 3); + learning_rate[i] = GetDeviceAddress(inputs, i * input_num_ + 4); + gradient[i] = GetDeviceAddress(inputs, i * input_num_ + 5); + momentum[i] = GetDeviceAddress(inputs, i * input_num_ + 6); + } + T **weight_decay_dev = GetDeviceAddress(workspace, 0); + T **scale_dev = GetDeviceAddress(workspace, 1); + T **variable_dev = GetDeviceAddress(workspace, 2); + T **accumulation_dev = GetDeviceAddress(workspace, 3); + T **learning_rate_dev = GetDeviceAddress(workspace, 4); + S **gradient_dev = GetDeviceAddress(workspace, 5); + T **momentum_dev = GetDeviceAddress(workspace, 6); + size_t *elements_dev = GetDeviceAddress(workspace, 7); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(weight_decay_dev, weight_decay.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream), + "cudaMemCPY failed") + CombineFusedWeightDecayScaleMomentum(max_, num_, elements_dev, weight_decay_dev, scale_dev, variable_dev, + accumulation_dev, learning_rate_dev, gradient_dev, momentum_dev, stream); + } + size_t element_num_; + std::unique_ptr elements_; + size_t num_; + size_t max_; + int input_num_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h index 4cfd7d6548..175f694154 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h @@ -52,7 +52,7 @@ class FusedScaleMomentumGpuKernel : public GpuKernel { return false; } - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); for (size_t i = 0; i < variable_shape.size(); i++) { element_num_ *= variable_shape[i]; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h index 22f5a711a1..457c3b3cfd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h @@ -53,7 +53,7 @@ class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel { return false; } - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); for (size_t i = 0; i < variable_shape.size(); i++) { element_num_ *= variable_shape[i]; } diff --git a/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.cc new file mode 100644 index 0000000000..b7dcc26c97 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/combine_momentum_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector &node_list) { + std::vector inputs_device_format; + std::vector outputs_device_format; + std::vector inputs_device_type; + std::vector outputs_device_type; + std::vector> outputs_shape; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + for (size_t idx = 0; idx < node_list.size(); ++idx) { + auto cnode = utils::cast(node_list[idx]); + MS_EXCEPTION_IF_NULL(cnode); + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + inputs_device_format.push_back(kOpFormat_DEFAULT); + inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index)); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + outputs_device_format.push_back(kOpFormat_DEFAULT); + outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); + } + } + builder.SetInputsFormat(inputs_device_format); + builder.SetOutputsFormat(outputs_device_format); + builder.SetInputsDeviceType(inputs_device_type); + builder.SetOutputsDeviceType(outputs_device_type); + return builder.Build(); +} +bool GetDealList(const std::vector &node_list, std::vector> *deal_list) { + std::vector momentum; + std::vector momentum_decay; + for (auto &momentum_node : node_list) { + if (momentum_node != nullptr && momentum_node->isa()) { + if (AnfAlgo::GetCNodeName(momentum_node) == kFusedScaleApplyMomentum) { + momentum.push_back(momentum_node); + } else if (AnfAlgo::GetCNodeName(momentum_node) == kFusedWeightScaleApplyMomentum) { + momentum_decay.push_back(momentum_node); + } + } + } + if (momentum.size() <= 1 && momentum_decay.size() <= 1) { + return false; + } + if (momentum.size() > 1) { + deal_list->push_back(momentum); + } + if (momentum_decay.size() > 1) { + deal_list->push_back(momentum_decay); + } + return true; +} +bool CombineMomentumFusion::Run(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector node_list = TopoSort(graph->get_return()); + // 1 get all the cast node + std::vector> deal_list; + if (!GetDealList(node_list, &deal_list)) { + return false; + } + for (auto momentums : deal_list) { + // 2 create node momentum + std::vector inputs = {}; + if (AnfAlgo::GetCNodeName(momentums[0]) == kFusedScaleApplyMomentum) { + auto prim = std::make_shared("CombineMomentum"); + MS_EXCEPTION_IF_NULL(prim); + inputs.push_back(NewValueNode(prim)); + } else { + auto prim = std::make_shared("CombineMomentumWeight"); + MS_EXCEPTION_IF_NULL(prim); + inputs.push_back(NewValueNode(prim)); + } + // set inputs for momentum + size_t input_num = AnfAlgo::GetInputTensorNum(momentums[0]); + for (auto mom : momentums) { + for (size_t i = 0; i < input_num; i++) { + inputs.push_back(AnfAlgo::GetInputNode(utils::cast(mom), i)); + } + } + auto combine_mom = graph->NewCNode(inputs); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + combine_mom->set_kernel_info(kernel_info); + AbstractBasePtrList abstract_list; + for (size_t idx = 0; idx < momentums.size(); ++idx) { + auto cnode = utils::cast(momentums[idx]); + MS_EXCEPTION_IF_NULL(cnode); + abstract_list.push_back(cnode->abstract()); + } + auto kernel_build_info = GenerateKernelBuildInfo(momentums); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, combine_mom.get()); + auto abstract_tuple = std::make_shared(abstract_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + combine_mom->set_abstract(abstract_tuple); + AnfAlgo::SetNodeAttr("n", MakeValue(momentums.size()), combine_mom); + // 3 replace all the cast by momentum + for (size_t idx = 0; idx < momentums.size(); ++idx) { + if (!manager->Replace(momentums[idx], combine_mom)) { + MS_LOG(EXCEPTION) << "manager replace node failed"; + } + } + } + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.h new file mode 100644 index 0000000000..b348adc431 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class CombineMomentumFusion : public Pass { + public: + explicit CombineMomentumFusion(const std::string &name) : Pass("combine_momentum") {} + ~CombineMomentumFusion() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_