| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "adam_weight_decay_impl.cuh" | |||||
| #include "device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| __global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1, | |||||
| const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2, | |||||
| const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v, | |||||
| T *param, T *gradient) { | |||||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) { | |||||
| float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i]; | |||||
| float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i]; | |||||
| float update = next_m / (sqrt(next_v) + epsilon[0]); | |||||
| if (need_decay && weight_decay != nullptr) { | |||||
| update += weight_decay[0] * param[i]; | |||||
| } | |||||
| param[i] -= lr[0] * update; | |||||
| m[i] = next_m; | |||||
| v[i] = next_v; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1, | |||||
| const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr, | |||||
| const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) { | |||||
| AdamWeightDecayKernel<<<GET_BLOCKS(element_num_), GET_THREADS, 0, stream>>>( | |||||
| element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param, | |||||
| gradient); | |||||
| } | |||||
| template void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, | |||||
| const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2, | |||||
| const float *epsilon, const float *lr, const float *weight_decay, float *m, float *v, | |||||
| float *param, float *gradient, cudaStream_t stream); | |||||
| @@ -0,0 +1,24 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_ | |||||
| template <typename T> | |||||
| void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1, | |||||
| const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr, | |||||
| const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_ | |||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/nn/fused_adam_weight_decay.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| FusedAdamWeightDecayGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(FusedAdam, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| FusedAdamWeightDecayGpuKernel, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,103 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "kernel/gpu/gpu_kernel.h" | |||||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||||
| #include "kernel/gpu/kernel_constants.h" | |||||
| #include "kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class FusedAdamWeightDecayGpuKernel : public GpuKernel { | |||||
| public: | |||||
| FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {} | |||||
| ~FusedAdamWeightDecayGpuKernel() override = default; | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| auto node_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| if (node_name == "AdamWeighDecay") { | |||||
| weight_decay_ = true; | |||||
| } | |||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7); | |||||
| element_nums_ = 1; | |||||
| for (auto i : shape) { | |||||
| element_nums_ *= i; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| float *beta1 = GetDeviceAddress<float>(inputs, 0); | |||||
| float *one_sub_beta1 = GetDeviceAddress<float>(inputs, 1); | |||||
| float *beta2 = GetDeviceAddress<float>(inputs, 2); | |||||
| float *one_sub_beta2 = GetDeviceAddress<float>(inputs, 3); | |||||
| float *epsilon = GetDeviceAddress<float>(inputs, 4); | |||||
| float *lr = GetDeviceAddress<float>(inputs, 5); | |||||
| T *param = GetDeviceAddress<T>(inputs, 6); | |||||
| T *m = GetDeviceAddress<T>(inputs, 7); | |||||
| T *v = GetDeviceAddress<T>(inputs, 8); | |||||
| T *gradient = GetDeviceAddress<T>(inputs, 9); | |||||
| float *weight_decay = nullptr; | |||||
| if (weight_decay_) { | |||||
| weight_decay = GetDeviceAddress<float>(inputs, 10); | |||||
| } | |||||
| AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, | |||||
| param, gradient, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitResource() override{}; | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(sizeof(float)); | |||||
| input_size_list_.push_back(sizeof(float)); | |||||
| input_size_list_.push_back(sizeof(float)); | |||||
| input_size_list_.push_back(sizeof(float)); | |||||
| input_size_list_.push_back(element_nums_ * sizeof(T)); | |||||
| input_size_list_.push_back(sizeof(float)); | |||||
| input_size_list_.push_back(sizeof(float)); | |||||
| input_size_list_.push_back(element_nums_ * sizeof(T)); | |||||
| if (weight_decay_) { | |||||
| input_size_list_.push_back(sizeof(float)); | |||||
| } | |||||
| output_size_list_.push_back(element_nums_ * sizeof(T)); | |||||
| } | |||||
| private: | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| int element_nums_; | |||||
| bool weight_decay_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ | |||||
| @@ -182,9 +182,11 @@ extern const PrimitivePtr kPrimReduceMin; | |||||
| extern const PrimitivePtr kPrimNeg; | extern const PrimitivePtr kPrimNeg; | ||||
| extern const PrimitivePtr kPrimSub; | extern const PrimitivePtr kPrimSub; | ||||
| extern const PrimitivePtr kPrimMul; | extern const PrimitivePtr kPrimMul; | ||||
| extern const PrimitivePtr kPrimRealDiv; | |||||
| extern const PrimitivePtr kPrimMinimum; | extern const PrimitivePtr kPrimMinimum; | ||||
| extern const PrimitivePtr kPrimMaximum; | extern const PrimitivePtr kPrimMaximum; | ||||
| extern const PrimitivePtr kPrimSquare; | extern const PrimitivePtr kPrimSquare; | ||||
| extern const PrimitivePtr kPrimSqrt; | |||||
| extern const PrimitivePtr kPrimEqual; | extern const PrimitivePtr kPrimEqual; | ||||
| extern const PrimitivePtr kPrimLess; | extern const PrimitivePtr kPrimLess; | ||||
| extern const PrimitivePtr kPrimLessEqual; | extern const PrimitivePtr kPrimLessEqual; | ||||
| @@ -0,0 +1,112 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/gpu/adam_fusion.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||||
| std::vector<std::string> inputs_format; | |||||
| std::vector<std::string> outputs_format; | |||||
| std::vector<TypeId> inputs_type; | |||||
| std::vector<TypeId> outputs_type; | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { | |||||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||||
| } | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { | |||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||||
| } | |||||
| builder.SetInputsDeviceType(inputs_type); | |||||
| builder.SetInputsFormat(inputs_format); | |||||
| builder.SetOutputsDeviceType(outputs_type); | |||||
| builder.SetOutputsFormat(outputs_format); | |||||
| return builder.Build(); | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef AdamFusion::DefinePattern() const { | |||||
| VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), | |||||
| VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||||
| VectorRef next_v = | |||||
| VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), | |||||
| VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | |||||
| VectorRef update = VectorRef( | |||||
| {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | |||||
| VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); | |||||
| VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); | |||||
| VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); | |||||
| VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); | |||||
| return depend3; | |||||
| } | |||||
| const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]); | |||||
| auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]); | |||||
| auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]); | |||||
| auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]); | |||||
| auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]); | |||||
| auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]); | |||||
| auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]); | |||||
| auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]); | |||||
| auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]); | |||||
| auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||||
| MS_EXCEPTION_IF_NULL(beta1_input); | |||||
| MS_EXCEPTION_IF_NULL(one_sub_beta1_input); | |||||
| MS_EXCEPTION_IF_NULL(beta2_input); | |||||
| MS_EXCEPTION_IF_NULL(one_sub_beta2_input); | |||||
| MS_EXCEPTION_IF_NULL(eps_input); | |||||
| MS_EXCEPTION_IF_NULL(lr_input); | |||||
| MS_EXCEPTION_IF_NULL(param_input); | |||||
| MS_EXCEPTION_IF_NULL(m_input); | |||||
| MS_EXCEPTION_IF_NULL(v_input); | |||||
| MS_EXCEPTION_IF_NULL(gradient_input); | |||||
| auto prim = std::make_shared<Primitive>(kFusedAdamName); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| std::vector<AnfNodePtr> inputs = { | |||||
| NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, | |||||
| eps_input, lr_input, param_input, m_input, v_input, | |||||
| gradient_input}; | |||||
| auto adam = graph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(adam); | |||||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||||
| auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get()); | |||||
| adam->set_scope(node->scope()); | |||||
| auto build_info = GenerateKernelBuildInfo(adam); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); | |||||
| return adam; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ | |||||
| #include <memory> | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class AdamFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) { | |||||
| beta1_ = std::make_shared<Var>(); | |||||
| one_sub_beta1_ = std::make_shared<Var>(); | |||||
| beta2_ = std::make_shared<Var>(); | |||||
| one_sub_beta2_ = std::make_shared<Var>(); | |||||
| eps_ = std::make_shared<Var>(); | |||||
| lr_ = std::make_shared<Var>(); | |||||
| param_ = std::make_shared<Var>(); | |||||
| m_ = std::make_shared<Var>(); | |||||
| v_ = std::make_shared<Var>(); | |||||
| gradient_ = std::make_shared<Var>(); | |||||
| } | |||||
| ~AdamFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| VarPtr beta1_; | |||||
| VarPtr one_sub_beta1_; | |||||
| VarPtr beta2_; | |||||
| VarPtr one_sub_beta2_; | |||||
| VarPtr eps_; | |||||
| VarPtr lr_; | |||||
| VarPtr param_; | |||||
| VarPtr m_; | |||||
| VarPtr v_; | |||||
| VarPtr gradient_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ | |||||
| @@ -0,0 +1,117 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/gpu/adam_weight_decay_fusion.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||||
| std::vector<std::string> inputs_format; | |||||
| std::vector<std::string> outputs_format; | |||||
| std::vector<TypeId> inputs_type; | |||||
| std::vector<TypeId> outputs_type; | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { | |||||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||||
| } | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { | |||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||||
| } | |||||
| builder.SetInputsDeviceType(inputs_type); | |||||
| builder.SetInputsFormat(inputs_format); | |||||
| builder.SetOutputsDeviceType(outputs_type); | |||||
| builder.SetOutputsFormat(outputs_format); | |||||
| return builder.Build(); | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef AdamWeightDecayFusion::DefinePattern() const { | |||||
| VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), | |||||
| VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||||
| VectorRef next_v = | |||||
| VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), | |||||
| VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | |||||
| VectorRef update = VectorRef( | |||||
| {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | |||||
| VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); | |||||
| VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); | |||||
| VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); | |||||
| VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); | |||||
| VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); | |||||
| return depend3; | |||||
| } | |||||
| const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]); | |||||
| auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]); | |||||
| auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]); | |||||
| auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]); | |||||
| auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]); | |||||
| auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]); | |||||
| auto weight_decay_input = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]); | |||||
| auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]); | |||||
| auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]); | |||||
| auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]); | |||||
| auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||||
| MS_EXCEPTION_IF_NULL(beta1_input); | |||||
| MS_EXCEPTION_IF_NULL(one_sub_beta1_input); | |||||
| MS_EXCEPTION_IF_NULL(beta2_input); | |||||
| MS_EXCEPTION_IF_NULL(one_sub_beta2_input); | |||||
| MS_EXCEPTION_IF_NULL(eps_input); | |||||
| MS_EXCEPTION_IF_NULL(lr_input); | |||||
| MS_EXCEPTION_IF_NULL(weight_decay_input); | |||||
| MS_EXCEPTION_IF_NULL(param_input); | |||||
| MS_EXCEPTION_IF_NULL(m_input); | |||||
| MS_EXCEPTION_IF_NULL(v_input); | |||||
| MS_EXCEPTION_IF_NULL(gradient_input); | |||||
| auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| std::vector<AnfNodePtr> inputs = { | |||||
| NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, | |||||
| eps_input, lr_input, param_input, m_input, v_input, | |||||
| gradient_input, weight_decay_input}; | |||||
| auto adam_weight_decay = graph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(adam_weight_decay); | |||||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||||
| auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get()); | |||||
| adam_weight_decay->set_scope(node->scope()); | |||||
| auto build_info = GenerateKernelBuildInfo(adam_weight_decay); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); | |||||
| return adam_weight_decay; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ | |||||
| #include <memory> | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class AdamWeightDecayFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) { | |||||
| beta1_ = std::make_shared<Var>(); | |||||
| one_sub_beta1_ = std::make_shared<Var>(); | |||||
| beta2_ = std::make_shared<Var>(); | |||||
| one_sub_beta2_ = std::make_shared<Var>(); | |||||
| eps_ = std::make_shared<Var>(); | |||||
| lr_ = std::make_shared<Var>(); | |||||
| weight_decay_ = std::make_shared<Var>(); | |||||
| param_ = std::make_shared<Var>(); | |||||
| m_ = std::make_shared<Var>(); | |||||
| v_ = std::make_shared<Var>(); | |||||
| gradient_ = std::make_shared<Var>(); | |||||
| } | |||||
| ~AdamWeightDecayFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| VarPtr beta1_; | |||||
| VarPtr one_sub_beta1_; | |||||
| VarPtr beta2_; | |||||
| VarPtr one_sub_beta2_; | |||||
| VarPtr eps_; | |||||
| VarPtr lr_; | |||||
| VarPtr weight_decay_; | |||||
| VarPtr param_; | |||||
| VarPtr m_; | |||||
| VarPtr v_; | |||||
| VarPtr gradient_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ | |||||
| @@ -23,6 +23,8 @@ | |||||
| #include "pre_activate/common/helper.h" | #include "pre_activate/common/helper.h" | ||||
| #include "pre_activate/pass/communication_op_fusion.h" | #include "pre_activate/pass/communication_op_fusion.h" | ||||
| #include "pre_activate/pass/getitem_tuple.h" | #include "pre_activate/pass/getitem_tuple.h" | ||||
| #include "pre_activate/gpu/adam_weight_decay_fusion.h" | |||||
| #include "pre_activate/gpu/adam_fusion.h" | |||||
| #include "device/kernel_runtime_manager.h" | #include "device/kernel_runtime_manager.h" | ||||
| #include "predict/predict.h" | #include "predict/predict.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| @@ -53,6 +55,16 @@ void GPUSession::StartKernelRT() const { | |||||
| void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>()); | |||||
| pm->AddPass(std::make_shared<opt::AdamFusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| } | |||||
| void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| auto pm = std::make_shared<opt::PassManager>(); | auto pm = std::make_shared<opt::PassManager>(); | ||||
| pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | ||||
| @@ -151,14 +163,16 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList | |||||
| auto graph_id = graph_sum_; | auto graph_id = graph_sum_; | ||||
| auto graph = ConstructKernelGraph(lst, outputs); | auto graph = ConstructKernelGraph(lst, outputs); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| // Optimize | |||||
| Optimize(graph); | |||||
| // Select kernel build info | // Select kernel build info | ||||
| SelectKernel(graph); | SelectKernel(graph); | ||||
| // Convert kernel Graph to model | // Convert kernel Graph to model | ||||
| predictmodel::StepConvertGraph(graph); | predictmodel::StepConvertGraph(graph); | ||||
| // Start gpu kernel runtime | // Start gpu kernel runtime | ||||
| StartKernelRT(); | StartKernelRT(); | ||||
| // AllReduce Optimize | |||||
| Optimize(graph); | |||||
| // HardwareOptimize | |||||
| HardwareOptimize(graph); | |||||
| // Assign CUDA streams | // Assign CUDA streams | ||||
| AssignStream(graph); | AssignStream(graph); | ||||
| // Hide NoOp from execution graph | // Hide NoOp from execution graph | ||||
| @@ -51,6 +51,8 @@ class GPUSession : public SessionBasic { | |||||
| void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph); | void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph); | ||||
| void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph); | |||||
| void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph); | void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph); | ||||
| void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| @@ -161,6 +161,8 @@ constexpr auto kNMSWithMaskOpName = "NMSWithMask"; | |||||
| constexpr auto kSoftmaxGradExtOpName = "SoftmaxGradExt"; | constexpr auto kSoftmaxGradExtOpName = "SoftmaxGradExt"; | ||||
| constexpr auto kStridedReadOpName = "StridedRead"; | constexpr auto kStridedReadOpName = "StridedRead"; | ||||
| constexpr auto kStridedWriteOpName = "StridedWrite"; | constexpr auto kStridedWriteOpName = "StridedWrite"; | ||||
| constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay"; | |||||
| constexpr auto kFusedAdamName = "FusedAdam"; | |||||
| // attr key name | // attr key name | ||||
| constexpr auto kAttrInputNames = "input_names"; | constexpr auto kAttrInputNames = "input_names"; | ||||
| @@ -0,0 +1,87 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.common.parameter import Parameter | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, decay_flag=True): | |||||
| super(Net, self).__init__() | |||||
| self.decay_flag = decay_flag | |||||
| self.op_mul = P.Mul() | |||||
| self.op_square = P.Square() | |||||
| self.op_sqrt = P.Sqrt() | |||||
| self.op_cast = P.Cast() | |||||
| self.op_reshape = P.Reshape() | |||||
| self.op_shape = P.Shape() | |||||
| self.param = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='param') | |||||
| self.m = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='m') | |||||
| self.v = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='v') | |||||
| @ms_function | |||||
| def construct(self, beta1, beta2, gradient, eps, weight_decay_tensor, lr): | |||||
| param_fp32 = self.op_cast(self.param, mstype.float32) | |||||
| m_fp32 = self.op_cast(self.m, mstype.float32) | |||||
| v_fp32 = self.op_cast(self.v, mstype.float32) | |||||
| gradient_fp32 = self.op_cast(gradient, mstype.float32) | |||||
| next_m = self.op_mul(beta1, m_fp32) + \ | |||||
| self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) | |||||
| next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - \ | |||||
| beta2, self.op_square(gradient_fp32)) | |||||
| update = next_m / (eps + self.op_sqrt(next_v)) | |||||
| if self.decay_flag: | |||||
| update = self.op_mul(weight_decay_tensor, param_fp32) + update | |||||
| update_with_lr = self.op_mul(lr, update) | |||||
| next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32)) | |||||
| next_v = F.depend(next_v, F.assign(self.param, next_param)) | |||||
| next_v = F.depend(next_v, F.assign(self.m, next_m)) | |||||
| next_v = F.depend(next_v, F.assign(self.v, next_v)) | |||||
| return next_v | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test(): | |||||
| beta1 = Tensor(np.array([0.9]).astype(np.float32)) | |||||
| beta2 = Tensor(np.array([0.999]).astype(np.float32)) | |||||
| lr = Tensor(np.array([0.001]).astype(np.float32)) | |||||
| eps = Tensor(np.array([1e-6]).astype(np.float32)) | |||||
| weight_decay_tensor = Tensor(np.array([0.001]).astype(np.float32)) | |||||
| gradient = Tensor(np.array([0.01, 0.03, 0.05]).astype(np.float32)) | |||||
| opt = Net(True) | |||||
| _ = opt(beta1, beta2, gradient, eps, weight_decay_tensor, lr) | |||||
| param_expect = np.array([0.09971199, 0.29950103, 0.4993557]).astype(np.float32) | |||||
| m_expect = np.array([0.091, 0.273, 0.45499998]).astype(np.float32) | |||||
| v_expect = np.array([0.0999001, 0.29970092, 0.4995025]).astype(np.float32) | |||||
| assert np.allclose(opt.param.data.asnumpy(), param_expect) | |||||
| assert np.allclose(opt.m.data.asnumpy(), m_expect) | |||||
| assert np.allclose(opt.v.data.asnumpy(), v_expect) | |||||
| @@ -119,6 +119,10 @@ def test_4d_transpose_ab(): | |||||
| [[5612, 5810, 6008, 6206]]]] | [[5612, 5810, 6008, 6206]]]] | ||||
| assert (output.asnumpy() == expect).all() | assert (output.asnumpy() == expect).all() | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_4D_fp16(): | def test_4D_fp16(): | ||||
| input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16) | input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16) | ||||
| input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16) | input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16) | ||||
| @@ -126,13 +130,13 @@ def test_4D_fp16(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| net = BatchMatMulNet() | net = BatchMatMulNet() | ||||
| output = net(input_x, input_y) | output = net(input_x, input_y) | ||||
| expect = [[[[20, 23, 26, 29]], | |||||
| [[200, 212, 224, 236]], | |||||
| [[596, 617, 638, 659]], | |||||
| [[1208, 1238, 1268, 1298]]], | |||||
| [[[2036, 2075, 2114, 2153]], | |||||
| [[3080, 3128, 3176, 3224]], | |||||
| [[4340, 4397, 4454, 4511]], | |||||
| [[5816, 5882, 5948, 6014]]]] | |||||
| expect = np.array([[[[20, 23, 26, 29]], | |||||
| [[200, 212, 224, 236]], | |||||
| [[596, 617, 638, 659]], | |||||
| [[1208, 1238, 1268, 1298]]], | |||||
| [[[2036, 2076, 2114, 2152]], | |||||
| [[3080, 3128, 3176, 3224]], | |||||
| [[4340, 4396, 4456, 4510]], | |||||
| [[5816, 5880, 5948, 6016]]]]).astype(np.float16) | |||||
| assert (output.asnumpy() == expect).all() | assert (output.asnumpy() == expect).all() | ||||