Merge pull request !6818 from VectorSL/combine-momtags/v1.1.0
| @@ -99,7 +99,7 @@ template <typename T, typename S> | |||||
| __global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, | __global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, | ||||
| const T *learning_rate, const S *gradient, const T *momentum) { | const T *learning_rate, const S *gradient, const T *momentum) { | ||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) { | ||||
| accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]); | |||||
| accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]) * scale[0]; | |||||
| variable[i] -= learning_rate[0] * accumulation[i]; | variable[i] -= learning_rate[0] * accumulation[i]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -113,6 +113,56 @@ void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accu | |||||
| element_num, scale, variable, accumulation, learning_rate, gradient, momentum); | element_num, scale, variable, accumulation, learning_rate, gradient, momentum); | ||||
| } | } | ||||
| // CombineFusedScaleMomentum | |||||
| template <typename T, typename S> | |||||
| __global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t *element_num, | |||||
| T **scale, T **variable, T **accumulation, | |||||
| T **learning_rate, S **gradient, T **momentum) { | |||||
| for (size_t idx = 0; idx < num; idx++) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) { | |||||
| accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + static_cast<T>(gradient[idx][i]) * scale[idx][0]; | |||||
| variable[idx][i] -= learning_rate[idx][0] * accumulation[idx][i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T, typename S> | |||||
| void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale, | |||||
| T **variable, T **accumulation, T **learning_rate, S **gradient, | |||||
| T **momentum, cudaStream_t cuda_stream) { | |||||
| size_t thread_per_block = 256; | |||||
| size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block; | |||||
| CombineFusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>( | |||||
| num, elements, scale, variable, accumulation, learning_rate, gradient, momentum); | |||||
| } | |||||
| // end CombineFusedScaleMomentum | |||||
| // CombineFusedWeightDecayScaleMomentum | |||||
| template <typename T, typename S> | |||||
| __global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, const size_t *element_num, | |||||
| T **weight_decay, T **scale, T **variable, | |||||
| T **accumulation, T **learning_rate, S **gradient, | |||||
| T **momentum) { | |||||
| for (size_t idx = 0; idx < num; idx++) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) { | |||||
| T grad = (variable[idx][i] * weight_decay[idx][0] + static_cast<T>(gradient[idx][i])) * scale[idx][0]; | |||||
| accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + grad; | |||||
| variable[idx][i] -= learning_rate[idx][0] * accumulation[idx][i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T, typename S> | |||||
| void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element_num, | |||||
| T **weight_decay, T **scale, T **variable, T **accumulation, | |||||
| T **learning_rate, S **gradient, T **momentum, | |||||
| cudaStream_t cuda_stream) { | |||||
| size_t thread_per_block = 256; | |||||
| size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block; | |||||
| CombineFusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>( | |||||
| num, element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum); | |||||
| } | |||||
| // end CombineFusedWeightDecayScaleMomentum | |||||
| template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation, | template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation, | ||||
| const float *learning_rate, const float *gradient, | const float *learning_rate, const float *gradient, | ||||
| const float *momentum, bool use_nesterov, | const float *momentum, bool use_nesterov, | ||||
| @@ -142,3 +192,17 @@ template void FusedScaleMomentum(const size_t element_num, float *scale, float * | |||||
| template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation, | template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation, | ||||
| const float *learning_rate, const half *gradient, const float *momentum, | const float *learning_rate, const half *gradient, const float *momentum, | ||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements, | |||||
| float **weight_decay, float **scale, float **variable, | |||||
| float **accumulation, float **learning_rate, float **gradient, | |||||
| float **momentum, cudaStream_t cuda_stream); | |||||
| template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements, | |||||
| float **weight_decay, float **scale, float **variable, | |||||
| float **accumulation, float **learning_rate, half **gradient, | |||||
| float **momentum, cudaStream_t cuda_stream); | |||||
| template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale, | |||||
| float **variable, float **accumulation, float **learning_rate, | |||||
| float **gradient, float **momentum, cudaStream_t cuda_stream); | |||||
| template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale, | |||||
| float **variable, float **accumulation, float **learning_rate, | |||||
| half **gradient, float **momentum, cudaStream_t cuda_stream); | |||||
| @@ -28,5 +28,12 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T | |||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate, | void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate, | ||||
| const S *gradient, const T *momentum, cudaStream_t cuda_stream); | const S *gradient, const T *momentum, cudaStream_t cuda_stream); | ||||
| template <typename T, typename S> | |||||
| void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element, T **weight_decay, | |||||
| T **scale, T **variable, T **accumulation, T **learning_rate, S **gradient, | |||||
| T **momentum, cudaStream_t cuda_stream); | |||||
| template <typename T, typename S> | |||||
| void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *element, T **scale, T **variable, | |||||
| T **accumulation, T **learning_rate, S **gradient, T **momentum, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | ||||
| @@ -40,22 +40,12 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel | |||||
| std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second, | std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second, | ||||
| size_t attr_index) { | size_t attr_index) { | ||||
| if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { | if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { | ||||
| if (iter_second->at(attr_index).first.GetAllSame()) { | |||||
| auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first; | |||||
| for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) { | |||||
| (void)iter_second->at(attr_index).first.AddInputAttr(dtype); | |||||
| } | |||||
| } else { | |||||
| if (!iter_second->at(attr_index).first.GetAllSame()) { | |||||
| MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; | MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; | ||||
| } | } | ||||
| } | } | ||||
| if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { | if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { | ||||
| if (iter_second->at(attr_index).first.GetAllSame()) { | |||||
| auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first; | |||||
| for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) { | |||||
| (void)iter_second->at(attr_index).first.AddOutputAttr(dtype); | |||||
| } | |||||
| } else { | |||||
| if (!iter_second->at(attr_index).first.GetAllSame()) { | |||||
| MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; | MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; | ||||
| } | } | ||||
| } | } | ||||
| @@ -99,6 +89,7 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string & | |||||
| for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { | for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { | ||||
| CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index); | CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index); | ||||
| bool flag = true; | bool flag = true; | ||||
| auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize(); | |||||
| // data type matching check of all input parameters of kernel | // data type matching check of all input parameters of kernel | ||||
| for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) { | for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) { | ||||
| if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { | if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { | ||||
| @@ -110,7 +101,7 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string & | |||||
| << ", but the current device's computing capacity is " << marjor_sm; | << ", but the current device's computing capacity is " << marjor_sm; | ||||
| } | } | ||||
| if (kernel_info->GetInputDeviceType(input_index) != | if (kernel_info->GetInputDeviceType(input_index) != | ||||
| (iter->second)[attr_index].first.GetInputAttr(input_index).first) { | |||||
| (iter->second)[attr_index].first.GetInputAttr(input_index % attr_size).first) { | |||||
| flag = false; | flag = false; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -118,10 +109,11 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string & | |||||
| if (!flag) { | if (!flag) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| attr_size = (&(iter->second))->at(attr_index).first.GetOutputSize(); | |||||
| // data type matching check of all output parameters of kernel | // data type matching check of all output parameters of kernel | ||||
| for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) { | for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) { | ||||
| if (kernel_info->GetOutputDeviceType(output_index) != | if (kernel_info->GetOutputDeviceType(output_index) != | ||||
| (iter->second)[attr_index].first.GetOutputAttr(output_index).first) { | |||||
| (iter->second)[attr_index].first.GetOutputAttr(output_index % attr_size).first) { | |||||
| flag = false; | flag = false; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_TWO(CombineMomentum, | |||||
| KernelAttr() | |||||
| .AddAllSameAttr(true) | |||||
| .AddInputAttr(kNumberTypeFloat32) // scale | |||||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||||
| .AddInputAttr(kNumberTypeFloat32) // learning_rate | |||||
| .AddInputAttr(kNumberTypeFloat32) // gradient | |||||
| .AddInputAttr(kNumberTypeFloat32) // momentum | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| CombineMomentumGpuKernel, float, float) | |||||
| MS_REG_GPU_KERNEL_TWO(CombineMomentum, | |||||
| KernelAttr() | |||||
| .AddAllSameAttr(true) | |||||
| .AddInputAttr(kNumberTypeFloat32) // scale | |||||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||||
| .AddInputAttr(kNumberTypeFloat32) // learning_rate | |||||
| .AddInputAttr(kNumberTypeFloat16) // gradient | |||||
| .AddInputAttr(kNumberTypeFloat32) // momentum | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| CombineMomentumGpuKernel, float, half) | |||||
| MS_REG_GPU_KERNEL_TWO(CombineMomentumWeight, | |||||
| KernelAttr() | |||||
| .AddAllSameAttr(true) | |||||
| .AddInputAttr(kNumberTypeFloat32) // weight decay | |||||
| .AddInputAttr(kNumberTypeFloat32) // scale | |||||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||||
| .AddInputAttr(kNumberTypeFloat32) // learning_rate | |||||
| .AddInputAttr(kNumberTypeFloat32) // gradient | |||||
| .AddInputAttr(kNumberTypeFloat32) // momentum | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| CombineMomentumGpuKernel, float, float) | |||||
| MS_REG_GPU_KERNEL_TWO(CombineMomentumWeight, | |||||
| KernelAttr() | |||||
| .AddAllSameAttr(true) | |||||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||||
| .AddInputAttr(kNumberTypeFloat32) // learning_rate | |||||
| .AddInputAttr(kNumberTypeFloat16) // gradient | |||||
| .AddInputAttr(kNumberTypeFloat32) // momentum | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| CombineMomentumGpuKernel, float, half) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,207 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T, typename S> | |||||
| class CombineMomentumGpuKernel : public GpuKernel { | |||||
| public: | |||||
| CombineMomentumGpuKernel() : element_num_(1), num_(0), max_(0), input_num_(6) {} | |||||
| ~CombineMomentumGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &workspace, void *stream_ptr) override { | |||||
| const cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr); | |||||
| auto weight_decay = std::make_unique<T *[]>(input_num_ * num_); | |||||
| auto scale = std::make_unique<T *[]>(input_num_ * num_); | |||||
| auto variable = std::make_unique<T *[]>(input_num_ * num_); | |||||
| auto accumulation = std::make_unique<T *[]>(input_num_ * num_); | |||||
| auto learning_rate = std::make_unique<T *[]>(input_num_ * num_); | |||||
| auto gradient = std::make_unique<S *[]>(input_num_ * num_); | |||||
| auto momentum = std::make_unique<T *[]>(input_num_ * num_); | |||||
| if (input_num_ == 6) { | |||||
| LaunchCombineMom(inputs, workspace, stream, scale, variable, accumulation, learning_rate, gradient, momentum); | |||||
| } else { | |||||
| LaunchCombineMomWeightDecay(inputs, workspace, stream, weight_decay, scale, variable, accumulation, learning_rate, | |||||
| gradient, momentum); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| num_ = GetAttr<size_t>(kernel_node, "n"); | |||||
| elements_ = std::make_unique<size_t[]>(num_); | |||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| if (kernel_name == "CombineMomentum") { | |||||
| input_num_ = 6; | |||||
| } else { | |||||
| input_num_ = 7; | |||||
| workspace_size_list_.push_back(sizeof(T *) * num_); | |||||
| } | |||||
| for (size_t i = 0; i < num_; i++) { | |||||
| element_num_ = 1; | |||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i * input_num_ + input_num_ - 4); | |||||
| for (size_t j = 0; j < variable_shape.size(); j++) { | |||||
| element_num_ *= variable_shape[j]; | |||||
| } | |||||
| if (max_ < element_num_) { | |||||
| max_ = element_num_; | |||||
| } | |||||
| elements_[i] = element_num_; | |||||
| InitSizeLists(); | |||||
| } | |||||
| workspace_size_list_.push_back(sizeof(T *) * num_); | |||||
| workspace_size_list_.push_back(sizeof(T *) * num_); | |||||
| workspace_size_list_.push_back(sizeof(T *) * num_); | |||||
| workspace_size_list_.push_back(sizeof(T *) * num_); | |||||
| workspace_size_list_.push_back(sizeof(S *) * num_); | |||||
| workspace_size_list_.push_back(sizeof(T *) * num_); | |||||
| workspace_size_list_.push_back(sizeof(size_t) * num_); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| if (input_num_ == 7) { | |||||
| input_size_list_.push_back(sizeof(T)); | |||||
| } | |||||
| input_size_list_.push_back(sizeof(T)); | |||||
| input_size_list_.push_back(element_num_ * sizeof(T)); | |||||
| input_size_list_.push_back(element_num_ * sizeof(T)); | |||||
| input_size_list_.push_back(sizeof(T)); | |||||
| input_size_list_.push_back(element_num_ * sizeof(S)); | |||||
| input_size_list_.push_back(sizeof(T)); | |||||
| output_size_list_.push_back(element_num_ * sizeof(T)); | |||||
| } | |||||
| private: | |||||
| void LaunchCombineMom(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const cudaStream_t &stream, const std::unique_ptr<T *[]> &scale, | |||||
| const std::unique_ptr<T *[]> &variable, const std::unique_ptr<T *[]> &accumulation, | |||||
| const std::unique_ptr<T *[]> &learning_rate, const std::unique_ptr<S *[]> &gradient, | |||||
| const std::unique_ptr<T *[]> &momentum) { | |||||
| for (size_t i = 0; i < num_; i++) { | |||||
| scale[i] = GetDeviceAddress<T>(inputs, i * input_num_); | |||||
| variable[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 1); | |||||
| accumulation[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 2); | |||||
| learning_rate[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 3); | |||||
| gradient[i] = GetDeviceAddress<S>(inputs, i * input_num_ + 4); | |||||
| momentum[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 5); | |||||
| } | |||||
| T **scale_dev = GetDeviceAddress<T *>(workspace, 0); | |||||
| T **variable_dev = GetDeviceAddress<T *>(workspace, 1); | |||||
| T **accumulation_dev = GetDeviceAddress<T *>(workspace, 2); | |||||
| T **learning_rate_dev = GetDeviceAddress<T *>(workspace, 3); | |||||
| S **gradient_dev = GetDeviceAddress<S *>(workspace, 4); | |||||
| T **momentum_dev = GetDeviceAddress<T *>(workspace, 5); | |||||
| size_t *elements_dev = GetDeviceAddress<size_t>(workspace, 6); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CombineFusedScaleMomentum(max_, num_, elements_dev, scale_dev, variable_dev, accumulation_dev, learning_rate_dev, | |||||
| gradient_dev, momentum_dev, stream); | |||||
| } | |||||
| void LaunchCombineMomWeightDecay(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const cudaStream_t &stream, const std::unique_ptr<T *[]> &weight_decay, | |||||
| const std::unique_ptr<T *[]> &scale, const std::unique_ptr<T *[]> &variable, | |||||
| const std::unique_ptr<T *[]> &accumulation, | |||||
| const std::unique_ptr<T *[]> &learning_rate, const std::unique_ptr<S *[]> &gradient, | |||||
| const std::unique_ptr<T *[]> &momentum) { | |||||
| for (size_t i = 0; i < num_; i++) { | |||||
| weight_decay[i] = GetDeviceAddress<T>(inputs, i * input_num_); | |||||
| scale[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 1); | |||||
| variable[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 2); | |||||
| accumulation[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 3); | |||||
| learning_rate[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 4); | |||||
| gradient[i] = GetDeviceAddress<S>(inputs, i * input_num_ + 5); | |||||
| momentum[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 6); | |||||
| } | |||||
| T **weight_decay_dev = GetDeviceAddress<T *>(workspace, 0); | |||||
| T **scale_dev = GetDeviceAddress<T *>(workspace, 1); | |||||
| T **variable_dev = GetDeviceAddress<T *>(workspace, 2); | |||||
| T **accumulation_dev = GetDeviceAddress<T *>(workspace, 3); | |||||
| T **learning_rate_dev = GetDeviceAddress<T *>(workspace, 4); | |||||
| S **gradient_dev = GetDeviceAddress<S *>(workspace, 5); | |||||
| T **momentum_dev = GetDeviceAddress<T *>(workspace, 6); | |||||
| size_t *elements_dev = GetDeviceAddress<size_t>(workspace, 7); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(weight_decay_dev, weight_decay.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream), | |||||
| "cudaMemCPY failed") | |||||
| CombineFusedWeightDecayScaleMomentum(max_, num_, elements_dev, weight_decay_dev, scale_dev, variable_dev, | |||||
| accumulation_dev, learning_rate_dev, gradient_dev, momentum_dev, stream); | |||||
| } | |||||
| size_t element_num_; | |||||
| std::unique_ptr<size_t[]> elements_; | |||||
| size_t num_; | |||||
| size_t max_; | |||||
| int input_num_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||||
| @@ -52,7 +52,7 @@ class FusedScaleMomentumGpuKernel : public GpuKernel { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| for (size_t i = 0; i < variable_shape.size(); i++) { | for (size_t i = 0; i < variable_shape.size(); i++) { | ||||
| element_num_ *= variable_shape[i]; | element_num_ *= variable_shape[i]; | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||||
| for (size_t i = 0; i < variable_shape.size(); i++) { | for (size_t i = 0; i < variable_shape.size(); i++) { | ||||
| element_num_ *= variable_shape[i]; | element_num_ *= variable_shape[i]; | ||||
| } | } | ||||
| @@ -0,0 +1,133 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/gpu/combine_momentum_fusion.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| #include "backend/optimizer/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) { | |||||
| std::vector<std::string> inputs_device_format; | |||||
| std::vector<std::string> outputs_device_format; | |||||
| std::vector<TypeId> inputs_device_type; | |||||
| std::vector<TypeId> outputs_device_type; | |||||
| std::vector<std::vector<size_t>> outputs_shape; | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| for (size_t idx = 0; idx < node_list.size(); ++idx) { | |||||
| auto cnode = utils::cast<CNodePtr>(node_list[idx]); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||||
| inputs_device_format.push_back(kOpFormat_DEFAULT); | |||||
| inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index)); | |||||
| } | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||||
| outputs_device_format.push_back(kOpFormat_DEFAULT); | |||||
| outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index)); | |||||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); | |||||
| } | |||||
| } | |||||
| builder.SetInputsFormat(inputs_device_format); | |||||
| builder.SetOutputsFormat(outputs_device_format); | |||||
| builder.SetInputsDeviceType(inputs_device_type); | |||||
| builder.SetOutputsDeviceType(outputs_device_type); | |||||
| return builder.Build(); | |||||
| } | |||||
| bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) { | |||||
| std::vector<AnfNodePtr> momentum; | |||||
| std::vector<AnfNodePtr> momentum_decay; | |||||
| for (auto &momentum_node : node_list) { | |||||
| if (momentum_node != nullptr && momentum_node->isa<CNode>()) { | |||||
| if (AnfAlgo::GetCNodeName(momentum_node) == kFusedScaleApplyMomentum) { | |||||
| momentum.push_back(momentum_node); | |||||
| } else if (AnfAlgo::GetCNodeName(momentum_node) == kFusedWeightScaleApplyMomentum) { | |||||
| momentum_decay.push_back(momentum_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (momentum.size() <= 1 && momentum_decay.size() <= 1) { | |||||
| return false; | |||||
| } | |||||
| if (momentum.size() > 1) { | |||||
| deal_list->push_back(momentum); | |||||
| } | |||||
| if (momentum_decay.size() > 1) { | |||||
| deal_list->push_back(momentum_decay); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool CombineMomentumFusion::Run(const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return()); | |||||
| // 1 get all the cast node | |||||
| std::vector<std::vector<AnfNodePtr>> deal_list; | |||||
| if (!GetDealList(node_list, &deal_list)) { | |||||
| return false; | |||||
| } | |||||
| for (auto momentums : deal_list) { | |||||
| // 2 create node momentum | |||||
| std::vector<AnfNodePtr> inputs = {}; | |||||
| if (AnfAlgo::GetCNodeName(momentums[0]) == kFusedScaleApplyMomentum) { | |||||
| auto prim = std::make_shared<Primitive>("CombineMomentum"); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| inputs.push_back(NewValueNode(prim)); | |||||
| } else { | |||||
| auto prim = std::make_shared<Primitive>("CombineMomentumWeight"); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| inputs.push_back(NewValueNode(prim)); | |||||
| } | |||||
| // set inputs for momentum | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(momentums[0]); | |||||
| for (auto mom : momentums) { | |||||
| for (size_t i = 0; i < input_num; i++) { | |||||
| inputs.push_back(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(mom), i)); | |||||
| } | |||||
| } | |||||
| auto combine_mom = graph->NewCNode(inputs); | |||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| combine_mom->set_kernel_info(kernel_info); | |||||
| AbstractBasePtrList abstract_list; | |||||
| for (size_t idx = 0; idx < momentums.size(); ++idx) { | |||||
| auto cnode = utils::cast<CNodePtr>(momentums[idx]); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| abstract_list.push_back(cnode->abstract()); | |||||
| } | |||||
| auto kernel_build_info = GenerateKernelBuildInfo(momentums); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, combine_mom.get()); | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||||
| MS_EXCEPTION_IF_NULL(abstract_tuple); | |||||
| combine_mom->set_abstract(abstract_tuple); | |||||
| AnfAlgo::SetNodeAttr("n", MakeValue(momentums.size()), combine_mom); | |||||
| // 3 replace all the cast by momentum | |||||
| for (size_t idx = 0; idx < momentums.size(); ++idx) { | |||||
| if (!manager->Replace(momentums[idx], combine_mom)) { | |||||
| MS_LOG(EXCEPTION) << "manager replace node failed"; | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class CombineMomentumFusion : public Pass { | |||||
| public: | |||||
| explicit CombineMomentumFusion(const std::string &name) : Pass("combine_momentum") {} | |||||
| ~CombineMomentumFusion() override = default; | |||||
| bool Run(const FuncGraphPtr &graph) override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_ | |||||