diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/raise_reduction_precision.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/raise_reduction_precision.cc
new file mode 100644
index 0000000000..eb6a52fcbb
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/raise_reduction_precision.cc
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "base/core_ops.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/tensor.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "runtime/device/kernel_info.h"
+
+namespace mindspore {
+namespace opt {
+// Pass intent: ReduceSum accumulating in fp16 loses precision; raise the
+// accumulation to fp32 by inserting Cast(fp16->fp32) before the reduction and
+// Cast(fp32->fp16) after it, inside graph-kernel sub-graphs.
+
+// A node qualifies when it is a ReduceSum CNode whose input device dtype is fp16.
+bool RaiseReductionPrecision::IsFp16ReduceSum(const AnfNodePtr &node) {
+  return IsPrimitiveCNode(node, prim::kPrimReduceSum) && AnfAlgo::GetInputDeviceDataType(node, 0) == kNumberTypeFloat16;
+}
+
+// Build a Cast CNode that converts `input` to `dst_type`, keeping `format` and
+// the input's shape. The "dst_type" attribute is required by the kernel build.
+AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, std::string format) {
+  auto func_graph = input->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast->Clone()), input};
+  auto cnode = CreateCNode(inputs, func_graph, {.format = format, .shape = GetShape(input), .type = dst_type});
+  AnfAlgo::SetNodeAttr("dst_type", MakeValue(kernel::TypeId2String(dst_type->type_id())), cnode);
+  return cnode;
+}
+
+// Retarget the existing ReduceSum node `node` to consume `input` (the fp32
+// cast) and rewrite its abstract and kernel-build info to fp32 in/out.
+// Mutates `node` in place and returns it.
+AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  cnode->set_input(1, input);
+  cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, GetShape(node)));
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
+  info_builder.SetInputsFormat({AnfAlgo::GetInputFormat(node, 0)});
+  info_builder.SetInputsDeviceType({kFloat32->type_id()});
+  info_builder.SetOutputsFormat({AnfAlgo::GetOutputFormat(node, 0)});
+  info_builder.SetOutputsDeviceType({kFloat32->type_id()});
+  info_builder.SetProcessor(AnfAlgo::GetProcessor(node));
+  info_builder.SetKernelType(KernelType::AKG_KERNEL);
+  info_builder.SetFusionType(kernel::FusionType::OPAQUE);
+  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), cnode.get());
+  return node;
+}
+
+// Redirect users of `reduce_node`: a user that is already a Cast-to-fp32 is
+// replaced by the (now fp32) reduce directly; every other CNode user is fed
+// the fp32->fp16 `cast_node` so downstream dtypes are unchanged.
+void RaiseReductionPrecision::ReplaceNode(const AnfNodePtr &reduce_node, const AnfNodePtr &cast_node) {
+  auto mng = reduce_node->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(mng);
+  // use a copy of user, since the following `mng->Replace` will change the original users of reduce_node.
+  auto users = mng->node_users()[reduce_node];
+  for (const auto &user : users) {
+    auto user_node = user.first;
+    auto user_index = user.second;
+    if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
+        AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat32) {
+      mng->Replace(user_node, reduce_node);
+    } else {
+      if (user_node->isa<CNode>()) {
+        user_node->cast<CNodePtr>()->set_input(user_index, cast_node);
+      }
+    }
+  }
+}
+
+// Rewrite every fp16 ReduceSum inside one graph-kernel sub-graph.
+bool RaiseReductionPrecision::Process(const FuncGraphPtr &func_graph) {
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+  auto todos = TopoSort(func_graph->get_return());
+  bool changed = false;
+  for (auto node : todos) {
+    if (IsFp16ReduceSum(node)) {
+      auto cast1 = CreateCast(node->cast<CNodePtr>()->input(1), kFloat32, AnfAlgo::GetInputFormat(node, 0));
+      auto new_reduce = CreateReduceSum(node, cast1);
+      auto cast2 = CreateCast(new_reduce, kFloat16, AnfAlgo::GetOutputFormat(node, 0));
+      ReplaceNode(node, cast2);
+      changed = true;
+    }
+  }
+  if (changed) {
+    mng->RemoveRoots();
+    mng->KeepRoots({func_graph});
+  }
+  return changed;
+}
+
+// Pass entry point: visit every graph-kernel node of the main graph and
+// process its sub-graph; returns true if any sub-graph was modified.
+bool RaiseReductionPrecision::Run(const FuncGraphPtr &func_graph) {
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+  auto todos = TopoSort(func_graph->get_return());
+  bool changed = false;
+  for (const auto &node : todos) {
+    if (AnfAlgo::IsGraphKernel(node)) {
+      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+      MS_ERROR_IF_NULL(sub_func_graph);
+      changed = Process(sub_func_graph) || changed;
+    }
+  }
+  if (changed) {
+    mng->RemoveRoots();
+    mng->KeepRoots({func_graph});
+  }
+  return changed;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/raise_reduction_precision.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/raise_reduction_precision.h
new file mode 100644
index 0000000000..1db4fe7864
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/raise_reduction_precision.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_RAISE_REDUCTION_PRECISION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_RAISE_REDUCTION_PRECISION_H_
+
+#include <string>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+// Graph-kernel pass that raises fp16 ReduceSum accumulation to fp32 by
+// wrapping the reduction with Cast(fp16->fp32) / Cast(fp32->fp16).
+class RaiseReductionPrecision : public Pass {
+ public:
+  RaiseReductionPrecision() : Pass("raise_reduction_precision") {}
+  ~RaiseReductionPrecision() override = default;
+  bool Run(const FuncGraphPtr &func_graph);
+
+ private:
+  bool IsFp16ReduceSum(const AnfNodePtr &node);
+  bool Process(const FuncGraphPtr &func_graph);
+  AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, std::string format);
+  AnfNodePtr CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input);
+  void ReplaceNode(const AnfNodePtr &src_node, const AnfNodePtr &dst_node);
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_RAISE_REDUCTION_PRECISION_H_
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index 69e3dd5001..04d7345f21 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -47,6 +47,7 @@
 #include "backend/optimizer/graph_kernel/tensor_promotion.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
+#include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
 #include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
 #include "backend/optimizer/graph_kernel/value_graph_binder.h"
@@ -182,6 +183,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
   pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
   pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
   pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
+  pm->AddPass(std::make_shared<opt::RaiseReductionPrecision>());
   pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
   pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
   pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));