Merge pull request !1790 from lianliguang/remove-the-useless-transdata-connected-with-the-control-depend
@@ -61,16 +61,14 @@ bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type,
 bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node,
                              size_t *cast_index) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  // Check whether the cast node is used for input by only one another node.
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  if (manager->node_users().find(node) == manager->node_users().end() || manager->node_users()[node].size() != 1) {
+  auto output_node_list = GetRealNodeUsedList(graph, node);
+  MS_EXCEPTION_IF_NULL(output_node_list);
+  if (output_node_list->size() != 1) {
     return false;
   }
-  *next_node = manager->node_users()[node].begin()->first;
-  *cast_index = IntToSize(manager->node_users()[node].begin()->second - 1);
+  auto node_pair = output_node_list->at(0);
+  *next_node = node_pair.first;
+  *cast_index = node_pair.second - 1;
   return true;
 }
@@ -148,7 +146,10 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
   if (alternative_kernel_info == kernel_info_list.end()) {
     return nullptr;
   }
-  MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_op_name;
+  auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node);
+  MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString()
+               << ", ori kernel info: " << ori_kernel_info->ToString() << ", alternative kernel info: "
+               << (*alternative_kernel_info)->ToString();
   AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
   if (node->inputs().size() < kCastInputNum) {
     auto op_name = AnfAlgo::GetCNodeName(node);
@@ -217,6 +218,10 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
   if (kernel_info_it == kernel_info_list.end()) {
     return nullptr;
   }
+  auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op);
+  MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString()
+               << ", ori kernel info: " << ori_kernel_info->ToString() << ", alternative kernel info: "
+               << (*kernel_info_it)->ToString();
   AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get());
   auto prior_name = AnfAlgo::GetCNodeName(prior_op);
@@ -16,6 +16,7 @@
 #include "pre_activate/common/helper.h"
 #include <string>
+#include <utility>
 #include <unordered_set>
 #include <algorithm>
 #include <map>
@@ -475,15 +476,36 @@ void RemoveNopNode(session::KernelGraph *const graph) {
   }
 }
 
-bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
+                                                                             const AnfNodePtr &node) {
+  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
   MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  if (manager->node_users().find(node) == manager->node_users().end()) {
+  auto iter = manager->node_users().find(node);
+  if (iter == manager->node_users().end()) {
     MS_LOG(EXCEPTION) << "node has no output in manager";
   }
-  return manager->node_users()[node].size() > 1;
+  auto output_info_list = iter->second;
+  for (const auto &output_info : output_info_list) {
+    if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
+      continue;
+    }
+    if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
+        output_info.second == kDependAttachNodeIndex) {
+      continue;
+    }
+    output_node_list->push_back(output_info);
+  }
+  return output_node_list;
+}
+
+bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto output_node_list = GetRealNodeUsedList(graph, node);
+  MS_EXCEPTION_IF_NULL(output_node_list);
+  return output_node_list->size() > 1;
 }
 
 AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
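Note on the new helper: GetRealNodeUsedList keeps only the consumers that actually use the node's data, skipping ControlDepend users entirely and skipping Depend users that take the node on the attach input (kDependAttachNodeIndex). To make that rule concrete, below is a minimal standalone sketch; the Node type, the user list, and the assumption that kDependAttachNodeIndex equals 2 are simplified stand-ins for the real AnfNodePtr / graph-manager machinery, not the MindSpore API.

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for a graph node and for the (user node, input index)
// pairs kept by the graph manager; not the real MindSpore types.
struct Node {
  std::string name;  // primitive name of the consuming CNode
};
using NodePtr = std::shared_ptr<Node>;
using UserList = std::vector<std::pair<NodePtr, int>>;

// Assumed value: the Depend input that only expresses ordering and carries no data.
constexpr int kDependAttachNodeIndex = 2;

// Mirror of the filtering rule in GetRealNodeUsedList: drop ControlDepend users,
// and drop Depend users when the node feeds the attach input.
UserList GetRealNodeUsedList(const UserList &all_users) {
  UserList real_users;
  for (const auto &user : all_users) {
    if (user.first->name == "ControlDepend") {
      continue;
    }
    if (user.first->name == "Depend" && user.second == kDependAttachNodeIndex) {
      continue;
    }
    real_users.push_back(user);
  }
  return real_users;
}

int main() {
  auto add_user = std::make_shared<Node>(Node{"Add"});
  auto control_depend_user = std::make_shared<Node>(Node{"ControlDepend"});
  auto depend_user = std::make_shared<Node>(Node{"Depend"});
  UserList users = {{add_user, 1}, {control_depend_user, 1}, {depend_user, kDependAttachNodeIndex}};
  // Only the Add user remains, so a Cast/TransData that is additionally
  // referenced by control edges is still treated as having one real consumer.
  std::cout << "real users: " << GetRealNodeUsedList(users).size() << std::endl;
  return 0;
}

Because MergeCastToOp and IsUsedByOthers now count only these real users, a cast whose extra references are pure ordering edges can still be merged instead of being kept alive.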
@@ -18,6 +18,7 @@
 #include <vector>
 #include <memory>
+#include <utility>
 #include <string>
 #include <unordered_set>
 #include "ir/func_graph.h"
@@ -163,6 +164,9 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt
 bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
+std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
+                                                                             const AnfNodePtr &node);
+
 void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
 
 bool AnfEqual(const BaseRef &a, const BaseRef &b);
@@ -44,11 +44,11 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
   return cnode->input(kSingleInputIndex);
 }
 
-bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(cnode);
   if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
-    return false;
+    return nullptr;
   }
   std::vector<AnfNodePtr> new_make_tuple_inputs;
   bool need_update = false;
@@ -75,17 +75,16 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
     auto manager = func_graph->manager();
     MS_EXCEPTION_IF_NULL(manager);
     manager->Replace(cnode, new_make_tuple);
+    return new_make_tuple;
   }
-  return true;
+  return nullptr;
 }
 }  // namespace
 
 const BaseRef OptimizeDependence::DefinePattern() const {
-  VarPtr X = std::make_shared<Var>("X");
-  MS_EXCEPTION_IF_NULL(X);
-  VarPtr Y = std::make_shared<Var>("Y");
-  MS_EXCEPTION_IF_NULL(Y);
-  return VectorRef({prim::kPrimDepend, X, Y});
+  VarPtr X = std::make_shared<Var>();
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({X, Xs});
 }
 
 const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
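Note on the pattern change: the old pattern only matched Depend(X, Y), so ControlDepend nodes never reached Process. The new pattern, a Var followed by a SeqVar, matches any CNode with any number of inputs, and the primitive check moves into Process (next hunk). The tiny compilable sketch below only illustrates that dispatch; IsOptimizeDependenceTarget is a hypothetical helper name, and plain strings stand in for prim::kPrimDepend / prim::kPrimControlDepend.

#include <iostream>
#include <string>

// With the broadened pattern every CNode is offered to the pass, so Process
// must decide for itself which primitives it actually rewrites.
bool IsOptimizeDependenceTarget(const std::string &node_name) {
  return node_name == "Depend" || node_name == "ControlDepend";
}

int main() {
  std::cout << std::boolalpha;
  std::cout << IsOptimizeDependenceTarget("Depend") << "\n";         // true
  std::cout << IsOptimizeDependenceTarget("ControlDepend") << "\n";  // true
  std::cout << IsOptimizeDependenceTarget("TensorAdd") << "\n";      // false: pass skips it
  return 0;
}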
@@ -95,27 +94,48 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
   if (!node->isa<CNode>()) {
     return nullptr;
   }
+  auto node_name = AnfAlgo::GetCNodeName(node);
+  if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) {
+    return nullptr;
+  }
+  size_t index = 0;
   auto depend_cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(depend_cnode);
-  CheckCNodeInputSize(depend_cnode, kDependInputNum);
-  auto replacing_node = depend_cnode->input(kDependInputNum - 1);
-  MS_EXCEPTION_IF_NULL(replacing_node);
-  if (!replacing_node->isa<CNode>()) {
-    return nullptr;
+  std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)};
+  if (node_name == prim::kPrimDepend->name()) {
+    index = 1;
+    new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend));
   }
-  auto replacing_cnode = replacing_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(replacing_cnode);
-  // Deal with the make_tuple with TransData or Cast inputs.
-  if (ReplaceMakeTuple(func_graph, replacing_cnode)) {
-    return nullptr;
+  if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) {
+    MS_LOG(EXCEPTION) << "The depend node should have at least 2 inputs, but got "
+                      << AnfAlgo::GetInputTensorNum(depend_cnode) << ", node: " << depend_cnode->DebugString();
   }
-  AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
-  if (replace_node == nullptr) {
-    MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
-    return nullptr;
+  auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
+  while (index < input_num) {
+    auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
+    ++index;
+    MS_EXCEPTION_IF_NULL(replacing_node);
+    if (!replacing_node->isa<CNode>()) {
+      new_depend_inputs.push_back(replacing_node);
+      continue;
+    }
+    auto replacing_cnode = replacing_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(replacing_cnode);
+    // Deal with the make_tuple with TransData or Cast inputs.
+    auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode);
+    if (make_tuple_replace_node != nullptr) {
+      new_depend_inputs.push_back(make_tuple_replace_node);
+      continue;
+    }
+    AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
+    if (replace_node == nullptr) {
+      new_depend_inputs.push_back(replacing_node);
+      MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: "
+                    << node->DebugString();
+      continue;
+    }
+    new_depend_inputs.push_back(replace_node);
   }
-  std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex),
-                                               depend_cnode->input(kRealInputIndexInDepend), replace_node};
   auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
   CNodePtr new_depend;
   if (kernel_graph == nullptr) {
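Note on the reworked Process: instead of bailing out early, it rebuilds the whole input list. For Depend the first (real) input is copied through untouched and only the attach inputs are rewritten; for ControlDepend every input is a candidate. Each candidate that is a TransData or Cast with a single use (or a make_tuple of such nodes, via ReplaceMakeTuple) is replaced by its argument, so a TransData kept alive only by the control edge becomes dead and can be eliminated later. The sketch below models just that loop under simplifying assumptions: the Node type and this GetReplaceNode are stand-ins, not the real CNodePtr/AnfAlgo interfaces, and the real helper additionally checks the user count and kernel formats.

#include <memory>
#include <string>
#include <vector>

// Simplified stand-in for a CNode: a primitive name plus its data inputs.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Stand-in for GetReplaceNode: if the node is a TransData/Cast, return its
// single data input, otherwise nullptr.
NodePtr GetReplaceNode(const NodePtr &node) {
  if ((node->name == "TransData" || node->name == "Cast") && !node->inputs.empty()) {
    return node->inputs.front();
  }
  return nullptr;
}

// Sketch of the rewritten Process loop: Depend keeps its first (real) input
// as-is and only rewrites the attach inputs; ControlDepend rewrites all inputs.
std::vector<NodePtr> RewriteDependInputs(const NodePtr &depend) {
  std::vector<NodePtr> new_inputs;
  size_t index = 0;
  if (depend->name == "Depend") {
    new_inputs.push_back(depend->inputs[0]);  // real input stays untouched
    index = 1;
  }
  for (; index < depend->inputs.size(); ++index) {
    NodePtr replace = GetReplaceNode(depend->inputs[index]);
    new_inputs.push_back(replace != nullptr ? replace : depend->inputs[index]);
  }
  return new_inputs;
}

int main() {
  auto a = std::make_shared<Node>(Node{"Parameter", {}});
  auto trans = std::make_shared<Node>(Node{"TransData", {a}});
  auto y = std::make_shared<Node>(Node{"Parameter", {}});
  auto control_depend = std::make_shared<Node>(Node{"ControlDepend", {y, trans}});
  // After the rewrite the control edge points at the Parameter directly,
  // mirroring how the pass makes the useless TransData removable.
  auto new_inputs = RewriteDependInputs(control_depend);
  return new_inputs[1] == a ? 0 : 1;
}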
@@ -201,18 +201,18 @@ def test_bert_percision():
     loss_value = np.array(callback.loss_list)
     assert np.allclose(loss_value[0], 12.206575, 0, 0.000001)
-    expect_loss_value = [12.206575, 11.980493, 11.984225, 11.878742, 11.832555, 12.410444, 12.008799,
-                         12.620619, 12.22254, 12.4261055]
+    expect_loss_value = [12.206575, 11.865044, 11.828129, 11.826707, 11.82108, 12.407423, 12.005459,
+                         12.621225, 12.222903, 12.427446]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
 
     overflow = np.array(callback.overflow_list)
-    expect_overflow = [True, True, False, False, False, True, False, False, False, True]
+    expect_overflow = [False, False, False, True, False, False, False, True, False, False]
     print("overflow: {}".format(overflow))
     assert (overflow == expect_overflow).all()
 
     loss_scale = np.array(callback.lossscale_list)
-    expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
+    expect_loss_scale = [65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0]
     print("loss scale: {}".format(loss_scale))
     assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
@@ -259,27 +259,27 @@ def test_bert_performance():
     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [10.237753, 10.213153, 10.212972]
+    expect_loss_value = [10.235566, 10.207392, 10.206976]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
 
     overflow = np.array(callback.overflow_list)
-    expect_overflow = [False, False, False]
+    expect_overflow = [True, True, True]
     print("overflow: {}".format(overflow))
     assert (overflow == expect_overflow).all()
 
     loss_scale = np.array(callback.lossscale_list)
-    expect_loss_scale = [16384.0, 16384.0, 16384.0]
+    expect_loss_scale = [262144.0, 262144.0, 262144.0]
     print("loss scale: {}".format(loss_scale))
     assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
 
     epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
-    expect_epoch_mseconds = 1726
+    expect_epoch_mseconds = 1600
     print("epoch mseconds: {}".format(epoch_mseconds))
     assert epoch_mseconds <= expect_epoch_mseconds + 5
 
     per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
-    expect_per_step_mseconds = 17
+    expect_per_step_mseconds = 16
     print("per step mseconds: {}".format(per_step_mseconds))
     assert per_step_mseconds <= expect_per_step_mseconds + 1
@@ -68,5 +68,47 @@ TEST_F(TestHWOptimizeDependence, test_optimize_dependence_with_make_tuple) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_dependence_with_make_tuple", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tuple) {
+  /*
+   * def before(x, y, a, b):
+   *     z = make_tuple(TransData(a), TransData(b))
+   *     depend_input = control_depend(y, z)
+   *     sum = add(x, depend_input)
+   *     return sum
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before");
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::OptimizeDependence>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(g);
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) {
+  /*
+   * def before(x, y, z):
+   *     new_z = TransData(z)
+   *     depend_input = control_depend(y, new_z)
+   *     sum = add(x, depend_input)
+   *     return sum
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before");
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::OptimizeDependence>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(g);
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -16,6 +16,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 
 depend = P.Depend()
+controldepend = Primitive("ControlDepend")
 TransData = Primitive('TransData')
 add = P.TensorAdd()
 make_tuple = Primitive('make_tuple')
@@ -69,3 +70,42 @@ def test_optimize_dependence_with_make_tuple(tag):
         return sum_add
 
     return fns[tag]
+
+
+def test_optimize_control_dependence(tag):
+    fns = FnDict()
+
+    @fns
+    def before(x, y, z):
+        new_z = TransData(z)
+        depend_input = controldepend(y, new_z)
+        sum_add = add(x, depend_input)
+        return sum_add
+
+    @fns
+    def after(x, y, z):
+        depend_input = controldepend(y, z)
+        sum_add = add(x, depend_input)
+        return sum_add
+
+    return fns[tag]
+
+
+def test_optimize_control_dependence_with_make_tuple(tag):
+    fns = FnDict()
+
+    @fns
+    def before(x, y, a, b):
+        z = make_tuple(TransData(a), TransData(b))
+        depend_input = controldepend(y, z)
+        sum_add = add(x, depend_input)
+        return sum_add
+
+    @fns
+    def after(x, y, a, b):
+        z = make_tuple(a, b)
+        depend_input = controldepend(y, z)
+        sum_add = add(x, depend_input)
+        return sum_add
+
+    return fns[tag]