Merge pull request !43 from fary86/dump-typed-graph-when-analyze-failtags/v0.2.0-alpha
| @@ -34,6 +34,7 @@ | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "operator/ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // max number of elements in sequence | // max number of elements in sequence | ||||
| @@ -69,7 +70,7 @@ py::object load_obj(const std::string& path) { | |||||
| // ============================================= MindSpore IR Exporter ============================================= | // ============================================= MindSpore IR Exporter ============================================= | ||||
| std::string GetNodeType(const AnfNodePtr& nd) { | |||||
| std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { | |||||
| abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); | abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); | ||||
| TypePtr type = dyn_cast<Type>(nd->Type()); | TypePtr type = dyn_cast<Type>(nd->Type()); | ||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| @@ -102,7 +103,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& | |||||
| FuncGraphPtr fg = func_graph; | FuncGraphPtr fg = func_graph; | ||||
| while (fg != nullptr) { | while (fg != nullptr) { | ||||
| if (exported.find(fg) == exported.end()) { | if (exported.find(fg) == exported.end()) { | ||||
| if (!export_used_) { | |||||
| if (!check_integrity_) { | |||||
| break; | break; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'"; | MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'"; | ||||
| @@ -255,15 +256,15 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||||
| } | } | ||||
| // output primitive attributes | // output primitive attributes | ||||
| auto attrs = prim->attrs(); | |||||
| if (attrs.size() > 0) { | |||||
| oss << "["; | |||||
| int i = 0; | |||||
| for (auto& attr : attrs) { | |||||
| oss << (i > 0 ? ", " : "") << attr.first << "=" << attr.second->DumpText(); | |||||
| i++; | |||||
| oss << prim->GetAttrsText(); | |||||
| if (prim->isa<prim::DoSignaturePrimitive>()) { | |||||
| auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim); | |||||
| auto& func = do_signature->function(); | |||||
| if (func->isa<Primitive>()) { | |||||
| auto sig_prim = dyn_cast<Primitive>(func); | |||||
| oss << sig_prim->GetAttrsText(); | |||||
| } | } | ||||
| oss << "]"; | |||||
| } | } | ||||
| return oss.str(); | return oss.str(); | ||||
| @@ -351,7 +352,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value | |||||
| std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { | std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { | ||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| if (export_used_) { | |||||
| if (check_integrity_) { | |||||
| MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText(); | MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText(); | ||||
| } | } | ||||
| oss << value->type_name() << "[" << value->DumpText() << "]"; | oss << value->type_name() << "[" << value->DumpText() << "]"; | ||||
| @@ -420,7 +421,7 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An | |||||
| } | } | ||||
| oss << "%" << iter->second; | oss << "%" << iter->second; | ||||
| } else if (node->isa<Parameter>()) { | } else if (node->isa<Parameter>()) { | ||||
| oss << "%para" << GetParamIndex(func_graph, node, export_used_); | |||||
| oss << "%para" << GetParamIndex(func_graph, node, check_integrity_); | |||||
| } else if (IsValueNode<FuncGraph>(node)) { | } else if (IsValueNode<FuncGraph>(node)) { | ||||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node); | FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node); | ||||
| oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id(); | oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id(); | ||||
| @@ -64,17 +64,18 @@ struct ParamPtrHasher { | |||||
| class AnfExporter { | class AnfExporter { | ||||
| public: | public: | ||||
| explicit AnfExporter(const std::string& id, bool export_used = true) | |||||
| : param_index(-1), id_(id), export_used_(export_used) { | |||||
| explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) | |||||
| : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { | |||||
| func_graph_set.clear(); | func_graph_set.clear(); | ||||
| exported.clear(); | exported.clear(); | ||||
| } | } | ||||
| ~AnfExporter() {} | |||||
| virtual ~AnfExporter() {} | |||||
| void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | ||||
| void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); | void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); | ||||
| private: | |||||
| protected: | |||||
| virtual std::string GetNodeType(const AnfNodePtr& nd); | |||||
| int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); | int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); | ||||
| int GetParamIndexFromExported(const AnfNodePtr& param); | int GetParamIndexFromExported(const AnfNodePtr& param); | ||||
| std::string DumpObject(const py::object& obj, const std::string& category) const; | std::string DumpObject(const py::object& obj, const std::string& category) const; | ||||
| @@ -101,8 +102,10 @@ class AnfExporter { | |||||
| OrderedSet<FuncGraphPtr> func_graph_set{}; | OrderedSet<FuncGraphPtr> func_graph_set{}; | ||||
| OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported; | OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported; | ||||
| std::string id_; | std::string id_; | ||||
| bool export_used_ = true; // whether export function graphs used in current exporting function graph | |||||
| bool export_used_ = true; // whether export function graphs used in current exporting function graph | |||||
| bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true | |||||
| TaggedNodeMap tagged_cnodes_; | TaggedNodeMap tagged_cnodes_; | ||||
| abstract::AnfNodeConfigPtr node_cfg_ = nullptr; | |||||
| }; | }; | ||||
| void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); | void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); | ||||
| @@ -115,7 +118,6 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); | |||||
| void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); | void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); | ||||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); | std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); | ||||
| std::string GetNodeType(const AnfNodePtr& nd); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <fstream> | |||||
| #include <map> | #include <map> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -194,37 +195,116 @@ void TraceGraphInfer() { | |||||
| MS_LOG(INFO) << "\n*************************************************************************************"; | MS_LOG(INFO) << "\n*************************************************************************************"; | ||||
| } | } | ||||
| void OutputAnalysisGraphInfo() { | |||||
| MS_LOG(INFO) << "Output analysis graph begin"; | |||||
| std::unordered_map<FuncGraphPtr, size_t> index_map; | |||||
| std::vector<TaggedGraph> tagged_graphs; | |||||
| class AnalyzedFuncGraphExporter : public AnfExporter { | |||||
| public: | |||||
| AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} | |||||
| ~AnalyzedFuncGraphExporter() override = default; | |||||
| void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs); | |||||
| private: | |||||
| std::string GetNodeType(const AnfNodePtr& nd) override; | |||||
| }; | |||||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() { | |||||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs; | |||||
| auto& list = GetCNodeDebugStack(); | auto& list = GetCNodeDebugStack(); | ||||
| for (size_t i = 0; i < list.size(); ++i) { | for (size_t i = 0; i < list.size(); ++i) { | ||||
| auto& node_cfg = list[i]; | |||||
| auto node_cfg = list[i]; | |||||
| auto fg = node_cfg->context()->func_graph(); | auto fg = node_cfg->context()->func_graph(); | ||||
| auto node = node_cfg->node(); | auto node = node_cfg->node(); | ||||
| auto idx = tagged_graphs.size(); | |||||
| std::pair<FuncGraphPtr, size_t> item(fg, idx); | |||||
| if (index_map.insert(item).second) { | |||||
| tagged_graphs.emplace_back(TaggedGraph(fg, TaggedNodeMap())); | |||||
| tagged_func_graphs[fg][node] = i; | |||||
| } | |||||
| return tagged_func_graphs; | |||||
| } | |||||
| void OutputAnalyzedGraphWithType() { | |||||
| AnalyzedFuncGraphExporter exporter; | |||||
| exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); | |||||
| } | |||||
| std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { | |||||
| if (node_cfg_ == nullptr) { | |||||
| return AnfExporter::GetNodeType(node); | |||||
| } | |||||
| auto ctx = node_cfg_->context(); | |||||
| auto engine = node_cfg_->engine(); | |||||
| auto cfg = engine->MakeConfig(node, ctx); | |||||
| auto abs = engine->cache().GetValue(cfg); | |||||
| if (abs == nullptr) { | |||||
| return "Undefined"; | |||||
| } | |||||
| auto dtype = abs->BuildType(); | |||||
| auto shape = abs->BuildShape(); | |||||
| std::ostringstream oss; | |||||
| if (dtype != nullptr && abs->isa<abstract::AbstractTensor>() && shape != nullptr) { | |||||
| oss << dtype->DumpText() << shape->DumpText(); | |||||
| } else if (dtype != nullptr) { | |||||
| oss << dtype->DumpText(); | |||||
| } else { | |||||
| oss << "Undefined"; | |||||
| } | |||||
| return oss.str(); | |||||
| } | |||||
| void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||||
| const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) { | |||||
| if (node_cfgs.empty()) { | |||||
| MS_LOG(DEBUG) << "Node configs is empty"; | |||||
| return; | |||||
| } | |||||
| std::ofstream ofs(filename); | |||||
| if (!ofs.is_open()) { | |||||
| MS_LOG(ERROR) << "Open file '" << filename << "' failed!"; | |||||
| return; | |||||
| } | |||||
| param_index = 1; | |||||
| auto tagged_func_graphs = CalcTaggedFuncGraphs(); | |||||
| // first output grapn on the analysis stack | |||||
| for (const auto& node_cfg : node_cfgs) { | |||||
| auto fg = node_cfg->context()->func_graph(); | |||||
| // the graph is already output, skip it | |||||
| if (exported.find(fg) != exported.end()) { | |||||
| continue; | |||||
| } | } | ||||
| tagged_graphs[index_map[fg]].second[node] = i; | |||||
| // set node_cfg info for getting type | |||||
| node_cfg_ = node_cfg; | |||||
| tagged_cnodes_ = tagged_func_graphs[fg]; | |||||
| ExportOneFuncGraph(ofs, fg); | |||||
| ofs << "\n\n"; | |||||
| } | |||||
| node_cfg_ = nullptr; | |||||
| tagged_cnodes_.clear(); | |||||
| // print seperator between function graphs on analyzed graph call stack and others | |||||
| ofs << "#===============================================================================\n\n\n"; | |||||
| // second output other graphs | |||||
| while (!func_graph_set.empty()) { | |||||
| FuncGraphPtr fg = *func_graph_set.begin(); | |||||
| ExportOneFuncGraph(ofs, fg); | |||||
| ofs << "\n\n"; | |||||
| (void)func_graph_set.erase(fg); | |||||
| } | } | ||||
| ofs << "# num of total funcgraphs: " << exported.size(); | |||||
| ExportIR("analyze_fail.dat", tagged_graphs); | |||||
| MS_LOG(INFO) << "Output analysis graph *end*"; | |||||
| ofs.close(); | |||||
| } | } | ||||
| void GetInferStackInfo(std::ostringstream& oss) { | void GetInferStackInfo(std::ostringstream& oss) { | ||||
| MS_LOG(INFO) << "Get graph analysis information begin"; | MS_LOG(INFO) << "Get graph analysis information begin"; | ||||
| auto& stack = GetCNodeDebugStack(); | |||||
| auto stack = GetCNodeDebugStack(); | |||||
| if (stack.empty()) { | if (stack.empty()) { | ||||
| MS_LOG(INFO) << "Length of analysis information stack is empty."; | MS_LOG(INFO) << "Length of analysis information stack is empty."; | ||||
| return; | return; | ||||
| } | } | ||||
| OutputAnalysisGraphInfo(); | |||||
| OutputAnalyzedGraphWithType(); | |||||
| oss << "\nThe function call stack:\n"; | oss << "\nThe function call stack:\n"; | ||||
| int index = 0; | int index = 0; | ||||
| @@ -106,6 +106,27 @@ void Primitive::set_signatures( | |||||
| } | } | ||||
| } | } | ||||
| std::string Primitive::GetAttrsText() const { | |||||
| if (attrs_.empty()) { | |||||
| return ""; | |||||
| } | |||||
| std::ostringstream oss; | |||||
| oss << "["; | |||||
| bool is_first = true; | |||||
| for (auto& attr : attrs_) { | |||||
| if (is_first) { | |||||
| is_first = false; | |||||
| } else { | |||||
| oss << ", "; | |||||
| } | |||||
| oss << attr.first << "=" << attr.second->DumpText(); | |||||
| } | |||||
| oss << "]"; | |||||
| return oss.str(); | |||||
| } | |||||
| py::function PrimitivePy::GetBpropFunction() { | py::function PrimitivePy::GetBpropFunction() { | ||||
| static const char* const get_bprop_func_name = "get_bprop"; | static const char* const get_bprop_func_name = "get_bprop"; | ||||
| if (py::hasattr(python_obj_, get_bprop_func_name)) { | if (py::hasattr(python_obj_, get_bprop_func_name)) { | ||||
| @@ -102,6 +102,7 @@ class Primitive : public Named { | |||||
| PrimType prim_type() const { return prim_type_; } | PrimType prim_type() const { return prim_type_; } | ||||
| std::string instance_name() const { return instance_name_; } | std::string instance_name() const { return instance_name_; } | ||||
| std::string GetAttrsText() const; | |||||
| bool operator==(const Value& other) const override; | bool operator==(const Value& other) const override; | ||||
| bool operator==(const Primitive& other) const; | bool operator==(const Primitive& other) const; | ||||
| ~Primitive() override = default; | ~Primitive() override = default; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "pipeline/static_analysis/prim.h" | #include "pipeline/static_analysis/prim.h" | ||||
| #include "pipeline/static_analysis/abstract_function.h" | #include "pipeline/static_analysis/abstract_function.h" | ||||
| #include "debug/trace.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| using Shape = abstract::Shape; | using Shape = abstract::Shape; | ||||
| @@ -124,6 +125,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { | |||||
| AbstractBasePtrList args_spec_list = {tuple_tensor, start_index}; | AbstractBasePtrList args_spec_list = {tuple_tensor, start_index}; | ||||
| try { | try { | ||||
| trace::ClearTraceStack(); | |||||
| engine_->Run(tupleSliceGraphPtr, args_spec_list); | engine_->Run(tupleSliceGraphPtr, args_spec_list); | ||||
| FAIL() << "Excepted exception :Args type is wrong"; | FAIL() << "Excepted exception :Args type is wrong"; | ||||
| } catch (std::runtime_error const &err) { | } catch (std::runtime_error const &err) { | ||||