From b86016a26ff4808b8afc660602b9b8a5385bd530 Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Thu, 4 Jun 2020 16:51:20 +0800 Subject: [PATCH] remove the useless transdata and cast of control depend node --- .../ascend/format_type/merge_cast_to_op.cc | 23 ++++--- mindspore/ccsrc/pre_activate/common/helper.cc | 30 ++++++-- mindspore/ccsrc/pre_activate/common/helper.h | 4 ++ .../pre_activate/pass/optimize_dependence.cc | 68 ++++++++++++------- .../models/bert/test_bert_tdt_lossscale.py | 18 ++--- .../pass/optimize_dependence_test.cc | 42 ++++++++++++ .../pre_activate/optimize_dependence_test.py | 40 +++++++++++ 7 files changed, 179 insertions(+), 46 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc index dc47757e5d..8bb58c18a5 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc @@ -61,16 +61,14 @@ bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node, size_t *cast_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - // Check whether the cast node is used for input by only one another node. 
- auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(node) == manager->node_users().end() || manager->node_users()[node].size() != 1) { + auto output_node_list = GetRealNodeUsedList(graph, node); + MS_EXCEPTION_IF_NULL(output_node_list); + if (output_node_list->size() != 1) { return false; } - *next_node = manager->node_users()[node].begin()->first; - *cast_index = IntToSize(manager->node_users()[node].begin()->second - 1); + auto node_pair = output_node_list->at(0); + *next_node = node_pair.first; + *cast_index = node_pair.second - 1; return true; } @@ -148,7 +146,10 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co if (alternative_kernel_info == kernel_info_list.end()) { return nullptr; } - MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_op_name; + auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node); + MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString() + << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" + << (*alternative_kernel_info)->ToString(); AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); if (node->inputs().size() < kCastInputNum) { auto op_name = AnfAlgo::GetCNodeName(node); @@ -217,6 +218,10 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod if (kernel_info_it == kernel_info_list.end()) { return nullptr; } + auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op); + MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString() + << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" + << (*kernel_info_it)->ToString(); AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get()); auto prior_name = AnfAlgo::GetCNodeName(prior_op); diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc 
b/mindspore/ccsrc/pre_activate/common/helper.cc index 896ab71e09..82a5550bd8 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.cc +++ b/mindspore/ccsrc/pre_activate/common/helper.cc @@ -16,6 +16,7 @@ #include "pre_activate/common/helper.h" #include +#include #include #include #include @@ -473,15 +474,36 @@ void RemoveNopNode(session::KernelGraph *const graph) { } } -bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { +std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, + const AnfNodePtr &node) { + auto output_node_list = std::make_shared>>(); MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(node) == manager->node_users().end()) { + auto iter = manager->node_users().find(node); + if (iter == manager->node_users().end()) { MS_LOG(EXCEPTION) << "node has no output in manager"; } - return manager->node_users()[node].size() > 1; + auto output_info_list = iter->second; + for (const auto &output_info : output_info_list) { + if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { + continue; + } + if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && + output_info.second == kDependAttachNodeIndex) { + continue; + } + output_node_list->push_back(output_info); + } + return output_node_list; +} + +bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto output_node_list = GetRealNodeUsedList(graph, node); + MS_EXCEPTION_IF_NULL(output_node_list); + return output_node_list->size() > 1; } AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index 1206fe2430..08a13671f1 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ 
b/mindspore/ccsrc/pre_activate/common/helper.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include "ir/func_graph.h" @@ -163,6 +164,9 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); +std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, + const AnfNodePtr &node); + void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs); bool AnfEqual(const BaseRef &a, const BaseRef &b); diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc index 86a90a4dfe..ee480b9c86 100644 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc @@ -44,11 +44,11 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { return cnode->input(kSingleInputIndex); } -bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { +AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { - return false; + return nullptr; } std::vector new_make_tuple_inputs; bool need_update = false; @@ -75,17 +75,16 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { auto manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); manager->Replace(cnode, new_make_tuple); + return new_make_tuple; } - return true; + return nullptr; } } // namespace const BaseRef OptimizeDependence::DefinePattern() const { - VarPtr X = std::make_shared("X"); - MS_EXCEPTION_IF_NULL(X); - VarPtr Y = std::make_shared("Y"); - MS_EXCEPTION_IF_NULL(Y); - return VectorRef({prim::kPrimDepend, X, Y}); + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); } const AnfNodePtr 
OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, @@ -95,27 +94,48 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con if (!node->isa()) { return nullptr; } + auto node_name = AnfAlgo::GetCNodeName(node); + if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) { + return nullptr; + } + size_t index = 0; auto depend_cnode = node->cast(); MS_EXCEPTION_IF_NULL(depend_cnode); - CheckCNodeInputSize(depend_cnode, kDependInputNum); - auto replacing_node = depend_cnode->input(kDependInputNum - 1); - MS_EXCEPTION_IF_NULL(replacing_node); - if (!replacing_node->isa()) { - return nullptr; + std::vector new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)}; + if (node_name == prim::kPrimDepend->name()) { + index = 1; + new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend)); } - auto replacing_cnode = replacing_node->cast(); - MS_EXCEPTION_IF_NULL(replacing_cnode); - // Deal with the make_tuple with TransData or Cast inputs. - if (ReplaceMakeTuple(func_graph, replacing_cnode)) { - return nullptr; + if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) { + MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got " + << AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString(); } - AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); - if (replace_node == nullptr) { - MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. 
Depend node: " << node->DebugString(); - return nullptr; + auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode); + while (index < input_num) { + auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); + ++index; + MS_EXCEPTION_IF_NULL(replacing_node); + if (!replacing_node->isa()) { + new_depend_inputs.push_back(replacing_node); + continue; + } + auto replacing_cnode = replacing_node->cast(); + MS_EXCEPTION_IF_NULL(replacing_cnode); + // Deal with the make_tuple with TransData or Cast inputs. + auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode); + if (make_tuple_replace_node != nullptr) { + new_depend_inputs.push_back(make_tuple_replace_node); + continue; + } + AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); + if (replace_node == nullptr) { + new_depend_inputs.push_back(replacing_node); + MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " + << node->DebugString(); + continue; + } + new_depend_inputs.push_back(replace_node); } - std::vector new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex), - depend_cnode->input(kRealInputIndexInDepend), replace_node}; auto kernel_graph = func_graph->cast>(); CNodePtr new_depend; if (kernel_graph == nullptr) { diff --git a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py index bda4afacaa..29b4e7a542 100644 --- a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py @@ -201,18 +201,18 @@ def test_bert_percision(): loss_value = np.array(callback.loss_list) assert np.allclose(loss_value[0], 12.206575, 0, 0.000001) - expect_loss_value = [12.206575, 11.980493, 11.984225, 11.878742, 11.832555, 12.410444, 12.008799, - 12.620619, 12.22254, 12.4261055] + expect_loss_value = [12.206575, 11.865044, 11.828129, 11.826707, 11.82108, 12.407423, 12.005459, + 12.621225, 12.222903, 12.427446] print("loss value: 
{}".format(loss_value)) assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) overflow = np.array(callback.overflow_list) - expect_overflow = [True, True, False, False, False, True, False, False, False, True] + expect_overflow = [False, False, False, True, False, False, False, True, False, False] print("overflow: {}".format(overflow)) assert (overflow == expect_overflow).all() loss_scale = np.array(callback.lossscale_list) - expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0] + expect_loss_scale = [65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0] print("loss scale: {}".format(loss_scale)) assert np.allclose(loss_scale, expect_loss_scale, 0, 0) @@ -259,27 +259,27 @@ def test_bert_performance(): # assertion occurs while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) - expect_loss_value = [10.237753, 10.213153, 10.212972] + expect_loss_value = [10.235566, 10.207392, 10.206976] print("loss value: {}".format(loss_value)) assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) overflow = np.array(callback.overflow_list) - expect_overflow = [False, False, False] + expect_overflow = [True, True, True] print("overflow: {}".format(overflow)) assert (overflow == expect_overflow).all() loss_scale = np.array(callback.lossscale_list) - expect_loss_scale = [16384.0, 16384.0, 16384.0] + expect_loss_scale = [262144.0, 262144.0, 262144.0] print("loss scale: {}".format(loss_scale)) assert np.allclose(loss_scale, expect_loss_scale, 0, 0) epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2] - expect_epoch_mseconds = 1726 + expect_epoch_mseconds = 1600 print("epoch mseconds: {}".format(epoch_mseconds)) assert epoch_mseconds <= expect_epoch_mseconds + 5 per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2] - expect_per_step_mseconds = 17 + expect_per_step_mseconds = 16 print("per 
step mseconds: {}".format(per_step_mseconds)) assert per_step_mseconds <= expect_per_step_mseconds + 1 diff --git a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc index e95d63e93e..04461e6602 100644 --- a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc +++ b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc @@ -68,5 +68,47 @@ TEST_F(TestHWOptimizeDependence, test_optimize_dependence_with_make_tuple) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_dependence_with_make_tuple", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + + +TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tuple) { + /* + * def before(x, y, a, b): + * z = make_tuple(TransData(a), TransData(b)) + * depend_intput = control_depend(y, z) + * sum = add(x, depend_intput) + * return sum + */ + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before"); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(g); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + + +TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) { + /* + * def before(x, y, z): + * new_z = TransData(z) + * depend_intput = control_depend(y, new_z) + * sum = add(x, depend_intput) + * return sum + */ + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before"); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(g); + + FuncGraphPtr g_after = 
get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py index cba857a5cb..2d98b50e3f 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py @@ -16,6 +16,7 @@ from mindspore.ops import Primitive from mindspore.ops import operations as P depend = P.Depend() +controldepend = Primitive("ControlDepend") TransData = Primitive('TransData') add = P.TensorAdd() make_tuple = Primitive('make_tuple') @@ -69,3 +70,42 @@ def test_optimize_dependence_with_make_tuple(tag): return sum_add return fns[tag] + + +def test_optimize_control_dependence(tag): + fns = FnDict() + + @fns + def before(x, y, z): + new_z = TransData(z) + depend_intput = controldepend(y, new_z) + sum_add = add(x, depend_intput) + return sum_add + + @fns + def after(x, y, z): + depend_intput = controldepend(y, z) + sum_add = add(x, depend_intput) + return sum_add + + return fns[tag] + + +def test_optimize_control_dependence_with_make_tuple(tag): + fns = FnDict() + + @fns + def before(x, y, a, b): + z = make_tuple(TransData(a), TransData(b)) + depend_intput = controldepend(y, z) + sum_add = add(x, depend_intput) + return sum_add + + @fns + def after(x, y, a, b): + z = make_tuple(a, b) + depend_intput = controldepend(y, z) + sum_add = add(x, depend_intput) + return sum_add + + return fns[tag]