Merge pull request !8244 from DeshiChen/1104_eliminate_redundant_parametertags/v1.1.0
| @@ -18,6 +18,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" | #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" | ||||
| @@ -144,11 +145,35 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { | |||||
| return JsonDescToAnf(kernel_desc_str, ori_inputs); | return JsonDescToAnf(kernel_desc_str, ori_inputs); | ||||
| } | } | ||||
| void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) { | |||||
| const auto &ori_parameter = func_graph->parameters(); | |||||
| auto todos = TopoSort(func_graph->get_return()); | |||||
| std::unordered_set<AnfNodePtr> used_param; | |||||
| for (auto node : todos) { | |||||
| if (node->isa<Parameter>()) { | |||||
| used_param.insert(node); | |||||
| } | |||||
| } | |||||
| if (used_param.size() == ori_parameter.size()) { | |||||
| return; | |||||
| } | |||||
| AnfNodePtrList new_parameter, new_inputs; | |||||
| for (size_t i = 0; i < ori_parameter.size(); ++i) { | |||||
| if (used_param.count(ori_parameter[i])) { | |||||
| new_parameter.push_back(ori_parameter[i]); | |||||
| new_inputs.push_back((*inputs)[i]); | |||||
| } | |||||
| } | |||||
| func_graph->set_parameters(new_parameter); | |||||
| *inputs = std::move(new_inputs); | |||||
| } | |||||
| AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph, | AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph, | ||||
| const FuncGraphPtr &new_func_graph, const CNodePtr &node) { | const FuncGraphPtr &new_func_graph, const CNodePtr &node) { | ||||
| std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end()); | std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end()); | ||||
| AnfNodePtrList kernel_nodes; | AnfNodePtrList kernel_nodes; | ||||
| AnfNodePtrList outputs; | AnfNodePtrList outputs; | ||||
| EliminateRedundantParameters(new_func_graph, &inputs); | |||||
| kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes); | kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes); | ||||
| kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); | kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); | ||||
| auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs, false); | auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs, false); | ||||
| @@ -184,7 +209,6 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { | |||||
| auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node); | auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node); | ||||
| new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node))); | new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node))); | ||||
| MS_LOG(INFO) << "create new cnode success."; | |||||
| // replace origin node. | // replace origin node. | ||||
| (void)mng->Replace(node, graph_kernel_node); | (void)mng->Replace(node, graph_kernel_node); | ||||
| @@ -32,6 +32,7 @@ class GraphKernelExpander : public Pass { | |||||
| FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node); | FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node); | ||||
| bool DoExpand(const FuncGraphPtr &func_graph); | bool DoExpand(const FuncGraphPtr &func_graph); | ||||
| void ToPrimitive(const FuncGraphPtr &func_graph) const; | void ToPrimitive(const FuncGraphPtr &func_graph) const; | ||||
| void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs); | |||||
| AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph, | AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph, | ||||
| const CNodePtr &node); | const CNodePtr &node); | ||||
| bool CanExpand(const CNodePtr &node) { | bool CanExpand(const CNodePtr &node) { | ||||
| @@ -702,10 +702,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo | |||||
| std::unordered_set<PrimitivePtr> GetExpandOps() { | std::unordered_set<PrimitivePtr> GetExpandOps() { | ||||
| std::unordered_set<PrimitivePtr> expand_ops = { | std::unordered_set<PrimitivePtr> expand_ops = { | ||||
| prim::kPrimSquare, | |||||
| prim::kPrimBiasAdd, | |||||
| prim::kPrimBiasAddGrad, | |||||
| prim::kPrimGelu, | |||||
| prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, prim::kPrimGeluGrad, | |||||
| }; | }; | ||||
| return expand_ops; | return expand_ops; | ||||
| } | } | ||||
| @@ -0,0 +1,83 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn import Cell | |||||
| import mindspore.ops.operations as P | |||||
| import mindspore.ops.operations._grad_ops as G | |||||
# Module-wide execution settings: graph mode, graph-kernel flag on, GPU target.
context.set_context(
    mode=context.GRAPH_MODE,
    enable_graph_kernel=True,
    device_target="GPU",
)
class GeluNet(Cell):
    """Cell that applies a single `P.Gelu` operator to its input."""

    def __init__(self):
        super(GeluNet, self).__init__()
        self.gelu = P.Gelu()

    def construct(self, x):
        """Return the result of the stored Gelu operator applied to `x`."""
        activated = self.gelu(x)
        return activated
class GeluGradNet(Cell):
    """Cell that applies a single `G.GeluGrad` operator to (dy, x, y)."""

    def __init__(self):
        super(GeluGradNet, self).__init__()
        self.gelu_grad = G.GeluGrad()

    def construct(self, dy, x, y):
        """Return the result of the stored GeluGrad operator for `dy`, `x`, `y`."""
        grad = self.gelu_grad(dy, x, y)
        return grad
def CalGelu(x):
    """NumPy reference for the tanh-based GELU expression.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    element-wise and returns the result.
    """
    cubic_poly = x + 0.044715 * x * x * x
    tanh_arg = np.sqrt(2.0 / np.pi) * cubic_poly
    return 0.5 * x * (1.0 + np.tanh(tanh_arg))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu():
    """Compare GeluNet's output with the NumPy reference from CalGelu."""
    x_np = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    actual = GeluNet()(Tensor(x_np))
    reference = CalGelu(x_np)
    assert np.allclose(reference, actual.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_grad():
    """Compare GeluGradNet's output with a NumPy-computed gradient expression."""
    # Draw dy first, then x, so the random stream is consumed in the same order.
    dy_np = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    x_np = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    y_np = CalGelu(x_np)

    actual = GeluGradNet()(Tensor(dy_np), Tensor(x_np), Tensor(y_np))

    # Reference: dy * (0.5 * (1 + t) + 0.5 * x * (1 - t^2) * (a + b * x^2)),
    # where t is the tanh term below.
    tanh_term = np.tanh(0.7978845608 * (x_np + 0.044715 * x_np * x_np * x_np))
    poly_term = 0.7978845608 + 0.1070322244 * x_np * x_np
    grad_wrt_x = 0.5 * (1.0 + tanh_term) + 0.5 * x_np * (1.0 - tanh_term * tanh_term) * poly_term
    reference = dy_np * grad_wrt_x

    assert np.allclose(reference, actual.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)