| @@ -25,6 +25,7 @@ | |||
| #include "abstract/abstract_function.h" | |||
| #include "utils/flags.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/anf_utils.h" | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| @@ -246,6 +247,9 @@ bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::si | |||
| } | |||
| if (CheckReplace(node, main)) { | |||
| changes = true; | |||
| if (AnfUtils::GetDumpFlag(node) && !AnfUtils::GetDumpFlag(main)) { | |||
| AnfUtils::SetDumpFlag(main); | |||
| } | |||
| (void)manager->Replace(node, main); | |||
| (void)clear_set.insert(i); | |||
| } | |||
| @@ -25,6 +25,7 @@ | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "utils/anf_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -45,6 +46,7 @@ class MergeAddN : public AnfVisitor { | |||
| if (!is_match_ || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| addn_nodes_.push_back(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto addn = NewValueNode(GetValueNode(cnode->input(0))); | |||
| @@ -54,7 +56,9 @@ class MergeAddN : public AnfVisitor { | |||
| auto fg = node->func_graph(); | |||
| auto make_node = fg->NewCNode(args_); | |||
| return fg->NewCNode({addn, make_node}); | |||
| auto new_node = fg->NewCNode({addn, make_node}); | |||
| UpdateDumpFlag(new_node); | |||
| return new_node; | |||
| } | |||
| void Visit(const CNodePtr &cnode) override { | |||
| @@ -84,6 +88,7 @@ class MergeAddN : public AnfVisitor { | |||
| return; | |||
| } | |||
| addn_nodes_.push_back(first_input); | |||
| (void)Ys_.erase(Ys_.begin()); | |||
| (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); | |||
| (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); | |||
| @@ -104,6 +109,7 @@ class MergeAddN : public AnfVisitor { | |||
| return; | |||
| } | |||
| addn_nodes_.push_back(last_input); | |||
| Ys_.pop_back(); | |||
| (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); | |||
| (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); | |||
| @@ -133,14 +139,27 @@ class MergeAddN : public AnfVisitor { | |||
| Xs_.clear(); | |||
| Ys_.clear(); | |||
| args_.clear(); | |||
| addn_nodes_.clear(); | |||
| is_inner_ = false; | |||
| is_outer_ = false; | |||
| is_match_ = false; | |||
| } | |||
| void UpdateDumpFlag(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| for (const auto &addn : addn_nodes_) { | |||
| if (AnfUtils::GetDumpFlag(addn)) { | |||
| AnfUtils::SetDumpFlag(node); | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| private: | |||
| FuncGraphManagerPtr mng_{nullptr}; | |||
| std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{}; | |||
| std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{}, addn_nodes_{}; | |||
| bool is_inner_{false}, is_outer_{false}, is_match_{false}; | |||
| }; | |||
| @@ -26,6 +26,7 @@ | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "abstract/dshape.h" | |||
| #include "utils/anf_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -85,7 +86,11 @@ class ReduceOneEliminater : public AnfVisitor { | |||
| new_node->set_abstract(node_abstract); | |||
| return new_node; | |||
| } | |||
| return node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)}); | |||
| auto new_node = node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)}); | |||
| if (AnfUtils::GetDumpFlag(node)) { | |||
| AnfUtils::SetDumpFlag(new_node); | |||
| } | |||
| return new_node; | |||
| } | |||
| return nullptr; | |||
| @@ -29,6 +29,7 @@ | |||
| #include "utils/check_convert_utils.h" | |||
| #include "debug/dump_proto.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| using FloatPtr = std::shared_ptr<Float>; | |||
| @@ -71,6 +72,8 @@ static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_ | |||
| {64, mind_ir::TensorProto_DataType_FLOAT64}, | |||
| }; | |||
| static std::set<std::string> g_export_attr_blacklist = {kAttrDump}; | |||
| // Can build different builder according to format | |||
| class IrExportBuilder; | |||
| using IrExportBuilderPtr = std::shared_ptr<IrExportBuilder>; | |||
| @@ -625,6 +628,10 @@ bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons | |||
| auto prim = GetValueNode<PrimitivePtr>(op); | |||
| for (auto attr : prim->attrs()) { | |||
| MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); | |||
| auto iter = g_export_attr_blacklist.find(attr.first); | |||
| if (iter != g_export_attr_blacklist.end()) { | |||
| continue; | |||
| } | |||
| mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_name(attr.first); | |||
| auto attr_value = attr.second; | |||
| @@ -505,6 +505,7 @@ constexpr auto kPrimalAttrForwardNodeName = "forward_node_name"; | |||
| // attr value | |||
| constexpr auto kValueTargetSwitch = "target_switch"; | |||
| constexpr auto kValueTargetOther = "target_other"; | |||
| constexpr auto kValueTrue = "true"; | |||
| // env key | |||
| constexpr auto kGraphOpRun = "GRAPH_OP_RUN"; | |||
| @@ -250,4 +250,28 @@ bool AnfUtils::IsNodeInGraphKernel(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| } | |||
| void AnfUtils::SetDumpFlag(const AnfNodePtr &node) { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| return; | |||
| } | |||
| auto prim = GetCNodePrimitive(node); | |||
| if (prim != nullptr) { | |||
| prim->set_attr(kAttrDump, MakeValue(kValueTrue)); | |||
| } | |||
| } | |||
| bool AnfUtils::GetDumpFlag(const AnfNodePtr &node) { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto prim = GetCNodePrimitive(node); | |||
| if (prim != nullptr) { | |||
| auto attr = prim->GetAttr(kAttrDump); | |||
| if (attr != nullptr && attr->isa<StringImm>() && attr->cast<StringImmPtr>()->value() == kValueTrue) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -48,6 +48,10 @@ class AnfUtils { | |||
| static bool IsGraphKernel(const AnfNodePtr &node); | |||
| // check whether the node is a node in GraphKernel's subgraph. | |||
| static bool IsNodeInGraphKernel(const AnfNodePtr &node); | |||
| // Set dump flag to CNode's primitive. | |||
| static void SetDumpFlag(const AnfNodePtr &node); | |||
| // Get dump flag from CNode's primitive. | |||
| static bool GetDumpFlag(const AnfNodePtr &node); | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_UTILS_ANF_UTILS_H_ | |||