Merge pull request !1995 from lianliguang/r0.3tags/v0.3.1-alpha
| @@ -62,12 +62,12 @@ void ValidateOperation(const AnfNodePtr &node) { | |||
| void ValidateAbstract(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| MS_LOG(WARNING) << "Node to validate is invalid"; | |||
| MS_LOG(DEBUG) << "Node to validate is invalid"; | |||
| return; | |||
| } | |||
| AbstractBasePtr ptrBase = node->abstract(); | |||
| if (ptrBase == nullptr) { | |||
| MS_LOG(WARNING) << "Abstract is null in node: " << node->DebugString(); | |||
| MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString(); | |||
| return; | |||
| } | |||
| if (ptrBase->isa<AbstractClass>() || ptrBase->isa<AbstractJTagged>()) { | |||
| @@ -61,16 +61,14 @@ bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, | |||
| bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node, | |||
| size_t *cast_index) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| // Check whether the cast node is used for input by only one another node. | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| if (manager->node_users().find(node) == manager->node_users().end() || manager->node_users()[node].size() != 1) { | |||
| auto output_node_list = GetRealNodeUsedList(graph, node); | |||
| MS_EXCEPTION_IF_NULL(output_node_list); | |||
| if (output_node_list->size() != 1) { | |||
| return false; | |||
| } | |||
| *next_node = manager->node_users()[node].begin()->first; | |||
| *cast_index = IntToSize(manager->node_users()[node].begin()->second - 1); | |||
| auto node_pair = output_node_list->at(0); | |||
| *next_node = node_pair.first; | |||
| *cast_index = node_pair.second - 1; | |||
| return true; | |||
| } | |||
| @@ -148,7 +146,10 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co | |||
| if (alternative_kernel_info == kernel_info_list.end()) { | |||
| return nullptr; | |||
| } | |||
| MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_op_name; | |||
| auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node); | |||
| MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString() | |||
| << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" | |||
| << (*alternative_kernel_info)->ToString(); | |||
| AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); | |||
| if (node->inputs().size() < kCastInputNum) { | |||
| auto op_name = AnfAlgo::GetCNodeName(node); | |||
| @@ -217,8 +218,11 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod | |||
| if (kernel_info_it == kernel_info_list.end()) { | |||
| return nullptr; | |||
| } | |||
| auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op); | |||
| MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString() | |||
| << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" | |||
| << (*kernel_info_it)->ToString(); | |||
| AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get()); | |||
| auto prior_name = AnfAlgo::GetCNodeName(prior_op); | |||
| if (prior_name == kFive2FourOpName) { | |||
| AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op); | |||
| @@ -20,6 +20,7 @@ | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <set> | |||
| #include <utility> | |||
| #include <deque> | |||
| #include "utils/utils.h" | |||
| #include "utils/base_ref.h" | |||
| @@ -472,15 +473,38 @@ void RemoveNopNode(session::KernelGraph *const graph) { | |||
| } | |||
| } | |||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||
| std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, | |||
| const AnfNodePtr &node) { | |||
| std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> output_node_list = | |||
| std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| if (manager->node_users().find(node) == manager->node_users().end()) { | |||
| auto iter = manager->node_users().find(node); | |||
| if (iter == manager->node_users().end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||
| } | |||
| return manager->node_users()[node].size() > 1; | |||
| auto output_info_list = iter->second; | |||
| for (const auto &output_info : output_info_list) { | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | |||
| output_info.second == kDependAttachNodeIndex) { | |||
| continue; | |||
| } | |||
| output_node_list->push_back(output_info); | |||
| } | |||
| return output_node_list; | |||
| } | |||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto output_node_list = GetRealNodeUsedList(graph, node); | |||
| MS_EXCEPTION_IF_NULL(output_node_list); | |||
| return output_node_list->size() > 1; | |||
| } | |||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <unordered_set> | |||
| #include "ir/func_graph.h" | |||
| #include "session/kernel_graph.h" | |||
| @@ -160,6 +161,9 @@ void RemoveNopNode(session::KernelGraph *const graph); | |||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); | |||
| std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, | |||
| const AnfNodePtr &node); | |||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | |||
| @@ -44,11 +44,11 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { | |||
| return cnode->input(kSingleInputIndex); | |||
| } | |||
| bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { | |||
| return false; | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> new_make_tuple_inputs; | |||
| bool need_update = false; | |||
| @@ -75,17 +75,16 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Replace(cnode, new_make_tuple); | |||
| return new_make_tuple; | |||
| } | |||
| return true; | |||
| return nullptr; | |||
| } | |||
| } // namespace | |||
| const BaseRef OptimizeDependence::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>("X"); | |||
| MS_EXCEPTION_IF_NULL(X); | |||
| VarPtr Y = std::make_shared<Var>("Y"); | |||
| MS_EXCEPTION_IF_NULL(Y); | |||
| return VectorRef({prim::kPrimDepend, X, Y}); | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({X, Xs}); | |||
| } | |||
| const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| @@ -95,29 +94,50 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con | |||
| if (!node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(node) != prim::kPrimControlDepend->name() && | |||
| AnfAlgo::GetCNodeName(node) != prim::kPrimDepend->name()) { | |||
| return nullptr; | |||
| } | |||
| size_t index = 0; | |||
| auto depend_cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(depend_cnode); | |||
| CheckCNodeInputSize(depend_cnode, kDependInputNum); | |||
| auto replacing_node = depend_cnode->input(kDependInputNum - 1); | |||
| MS_EXCEPTION_IF_NULL(replacing_node); | |||
| if (!replacing_node->isa<CNode>()) { | |||
| return nullptr; | |||
| std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)}; | |||
| if (AnfAlgo::GetCNodeName(node) == prim::kPrimDepend->name()) { | |||
| index = 1; | |||
| new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend)); | |||
| } | |||
| auto replacing_cnode = replacing_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(replacing_cnode); | |||
| // Deal with the make_tuple with TransData or Cast inputs. | |||
| if (ReplaceMakeTuple(func_graph, replacing_cnode)) { | |||
| return nullptr; | |||
| if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) { | |||
| MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got " | |||
| << AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString(); | |||
| } | |||
| AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); | |||
| if (replace_node == nullptr) { | |||
| MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); | |||
| return nullptr; | |||
| while (index < AnfAlgo::GetInputTensorNum(depend_cnode)) { | |||
| auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); | |||
| ++index; | |||
| MS_EXCEPTION_IF_NULL(replacing_node); | |||
| if (!replacing_node->isa<CNode>()) { | |||
| new_depend_inputs.push_back(replacing_node); | |||
| continue; | |||
| } | |||
| auto replacing_cnode = replacing_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(replacing_cnode); | |||
| // Deal with the make_tuple with TransData or Cast inputs. | |||
| auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode); | |||
| if (make_tuple_replace_node != nullptr) { | |||
| new_depend_inputs.push_back(make_tuple_replace_node); | |||
| continue; | |||
| } | |||
| AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); | |||
| if (replace_node == nullptr) { | |||
| new_depend_inputs.push_back(replacing_node); | |||
| MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " | |||
| << node->DebugString(); | |||
| continue; | |||
| } | |||
| new_depend_inputs.push_back(replace_node); | |||
| } | |||
| std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex), | |||
| depend_cnode->input(kRealInputIndexInDepend), replace_node}; | |||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||
| CNodePtr new_depend; | |||
| CNodePtr new_depend = nullptr; | |||
| if (kernel_graph == nullptr) { | |||
| new_depend = func_graph->NewCNode(new_depend_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_depend); | |||
| @@ -171,18 +171,18 @@ def test_bert_tdt(): | |||
| # assertion occurs while the loss value, overflow state or loss_scale value is wrong | |||
| loss_value = np.array(callback.loss_list) | |||
| expect_loss_value = [12.207198, 11.980881, 11.984844, 11.879381, 11.832978, 12.411333, 12.009284, | |||
| 12.621277, 12.223178, 12.427385] | |||
| expect_loss_value = [12.207198, 11.865665, 11.828972, 11.827378, 11.821808, 12.408042, 12.00606, | |||
| 12.621794, 12.223485, 12.427612] | |||
| print("loss value: {}".format(loss_value)) | |||
| assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) | |||
| overflow = np.array(callback.overflow_list) | |||
| expect_overflow = [True, True, False, False, False, True, False, False, False, True] | |||
| expect_overflow = [False, False, False, True, False, False, False, True, False, False] | |||
| print("overflow: {}".format(overflow)) | |||
| assert (overflow == expect_overflow).all() | |||
| loss_scale = np.array(callback.lossscale_list) | |||
| expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0] | |||
| expect_loss_scale = [65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0] | |||
| print("loss scale: {}".format(loss_scale)) | |||
| assert np.allclose(loss_scale, expect_loss_scale, 0, 0) | |||