Merge pull request !43 from fary86/dump-typed-graph-when-analyze-failtags/v0.2.0-alpha
| @@ -34,6 +34,7 @@ | |||
| #include "utils/utils.h" | |||
| #include "debug/trace.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| // max number of elements in sequence | |||
| @@ -69,7 +70,7 @@ py::object load_obj(const std::string& path) { | |||
| // ============================================= MindSpore IR Exporter ============================================= | |||
| std::string GetNodeType(const AnfNodePtr& nd) { | |||
| std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { | |||
| abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); | |||
| TypePtr type = dyn_cast<Type>(nd->Type()); | |||
| std::ostringstream oss; | |||
| @@ -102,7 +103,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& | |||
| FuncGraphPtr fg = func_graph; | |||
| while (fg != nullptr) { | |||
| if (exported.find(fg) == exported.end()) { | |||
| if (!export_used_) { | |||
| if (!check_integrity_) { | |||
| break; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'"; | |||
| @@ -255,15 +256,15 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||
| } | |||
| // output primitive attributes | |||
| auto attrs = prim->attrs(); | |||
| if (attrs.size() > 0) { | |||
| oss << "["; | |||
| int i = 0; | |||
| for (auto& attr : attrs) { | |||
| oss << (i > 0 ? ", " : "") << attr.first << "=" << attr.second->DumpText(); | |||
| i++; | |||
| oss << prim->GetAttrsText(); | |||
| if (prim->isa<prim::DoSignaturePrimitive>()) { | |||
| auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim); | |||
| auto& func = do_signature->function(); | |||
| if (func->isa<Primitive>()) { | |||
| auto sig_prim = dyn_cast<Primitive>(func); | |||
| oss << sig_prim->GetAttrsText(); | |||
| } | |||
| oss << "]"; | |||
| } | |||
| return oss.str(); | |||
| @@ -351,7 +352,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value | |||
| std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { | |||
| std::ostringstream oss; | |||
| if (export_used_) { | |||
| if (check_integrity_) { | |||
| MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText(); | |||
| } | |||
| oss << value->type_name() << "[" << value->DumpText() << "]"; | |||
| @@ -420,7 +421,7 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An | |||
| } | |||
| oss << "%" << iter->second; | |||
| } else if (node->isa<Parameter>()) { | |||
| oss << "%para" << GetParamIndex(func_graph, node, export_used_); | |||
| oss << "%para" << GetParamIndex(func_graph, node, check_integrity_); | |||
| } else if (IsValueNode<FuncGraph>(node)) { | |||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node); | |||
| oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id(); | |||
| @@ -64,17 +64,18 @@ struct ParamPtrHasher { | |||
| class AnfExporter { | |||
| public: | |||
| explicit AnfExporter(const std::string& id, bool export_used = true) | |||
| : param_index(-1), id_(id), export_used_(export_used) { | |||
| explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) | |||
| : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { | |||
| func_graph_set.clear(); | |||
| exported.clear(); | |||
| } | |||
| ~AnfExporter() {} | |||
| virtual ~AnfExporter() {} | |||
| void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | |||
| void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); | |||
| private: | |||
| protected: | |||
| virtual std::string GetNodeType(const AnfNodePtr& nd); | |||
| int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); | |||
| int GetParamIndexFromExported(const AnfNodePtr& param); | |||
| std::string DumpObject(const py::object& obj, const std::string& category) const; | |||
| @@ -101,8 +102,10 @@ class AnfExporter { | |||
| OrderedSet<FuncGraphPtr> func_graph_set{}; | |||
| OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported; | |||
| std::string id_; | |||
| bool export_used_ = true; // whether export function graphs used in current exporting function graph | |||
| bool export_used_ = true; // whether export function graphs used in current exporting function graph | |||
| bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true | |||
| TaggedNodeMap tagged_cnodes_; | |||
| abstract::AnfNodeConfigPtr node_cfg_ = nullptr; | |||
| }; | |||
| void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); | |||
| @@ -115,7 +118,6 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); | |||
| void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); | |||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); | |||
| std::string GetNodeType(const AnfNodePtr& nd); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | |||
| @@ -17,6 +17,7 @@ | |||
| #include "debug/trace.h" | |||
| #include <iostream> | |||
| #include <fstream> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| @@ -194,37 +195,116 @@ void TraceGraphInfer() { | |||
| MS_LOG(INFO) << "\n*************************************************************************************"; | |||
| } | |||
| void OutputAnalysisGraphInfo() { | |||
| MS_LOG(INFO) << "Output analysis graph begin"; | |||
| std::unordered_map<FuncGraphPtr, size_t> index_map; | |||
| std::vector<TaggedGraph> tagged_graphs; | |||
| class AnalyzedFuncGraphExporter : public AnfExporter { | |||
| public: | |||
| AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} | |||
| ~AnalyzedFuncGraphExporter() override = default; | |||
| void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs); | |||
| private: | |||
| std::string GetNodeType(const AnfNodePtr& nd) override; | |||
| }; | |||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() { | |||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs; | |||
| auto& list = GetCNodeDebugStack(); | |||
| for (size_t i = 0; i < list.size(); ++i) { | |||
| auto& node_cfg = list[i]; | |||
| auto node_cfg = list[i]; | |||
| auto fg = node_cfg->context()->func_graph(); | |||
| auto node = node_cfg->node(); | |||
| auto idx = tagged_graphs.size(); | |||
| std::pair<FuncGraphPtr, size_t> item(fg, idx); | |||
| if (index_map.insert(item).second) { | |||
| tagged_graphs.emplace_back(TaggedGraph(fg, TaggedNodeMap())); | |||
| tagged_func_graphs[fg][node] = i; | |||
| } | |||
| return tagged_func_graphs; | |||
| } | |||
| void OutputAnalyzedGraphWithType() { | |||
| AnalyzedFuncGraphExporter exporter; | |||
| exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); | |||
| } | |||
| std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { | |||
| if (node_cfg_ == nullptr) { | |||
| return AnfExporter::GetNodeType(node); | |||
| } | |||
| auto ctx = node_cfg_->context(); | |||
| auto engine = node_cfg_->engine(); | |||
| auto cfg = engine->MakeConfig(node, ctx); | |||
| auto abs = engine->cache().GetValue(cfg); | |||
| if (abs == nullptr) { | |||
| return "Undefined"; | |||
| } | |||
| auto dtype = abs->BuildType(); | |||
| auto shape = abs->BuildShape(); | |||
| std::ostringstream oss; | |||
| if (dtype != nullptr && abs->isa<abstract::AbstractTensor>() && shape != nullptr) { | |||
| oss << dtype->DumpText() << shape->DumpText(); | |||
| } else if (dtype != nullptr) { | |||
| oss << dtype->DumpText(); | |||
| } else { | |||
| oss << "Undefined"; | |||
| } | |||
| return oss.str(); | |||
| } | |||
| void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||
| const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) { | |||
| if (node_cfgs.empty()) { | |||
| MS_LOG(DEBUG) << "Node configs is empty"; | |||
| return; | |||
| } | |||
| std::ofstream ofs(filename); | |||
| if (!ofs.is_open()) { | |||
| MS_LOG(ERROR) << "Open file '" << filename << "' failed!"; | |||
| return; | |||
| } | |||
| param_index = 1; | |||
| auto tagged_func_graphs = CalcTaggedFuncGraphs(); | |||
| // first output grapn on the analysis stack | |||
| for (const auto& node_cfg : node_cfgs) { | |||
| auto fg = node_cfg->context()->func_graph(); | |||
| // the graph is already output, skip it | |||
| if (exported.find(fg) != exported.end()) { | |||
| continue; | |||
| } | |||
| tagged_graphs[index_map[fg]].second[node] = i; | |||
| // set node_cfg info for getting type | |||
| node_cfg_ = node_cfg; | |||
| tagged_cnodes_ = tagged_func_graphs[fg]; | |||
| ExportOneFuncGraph(ofs, fg); | |||
| ofs << "\n\n"; | |||
| } | |||
| node_cfg_ = nullptr; | |||
| tagged_cnodes_.clear(); | |||
| // print seperator between function graphs on analyzed graph call stack and others | |||
| ofs << "#===============================================================================\n\n\n"; | |||
| // second output other graphs | |||
| while (!func_graph_set.empty()) { | |||
| FuncGraphPtr fg = *func_graph_set.begin(); | |||
| ExportOneFuncGraph(ofs, fg); | |||
| ofs << "\n\n"; | |||
| (void)func_graph_set.erase(fg); | |||
| } | |||
| ofs << "# num of total funcgraphs: " << exported.size(); | |||
| ExportIR("analyze_fail.dat", tagged_graphs); | |||
| MS_LOG(INFO) << "Output analysis graph *end*"; | |||
| ofs.close(); | |||
| } | |||
| void GetInferStackInfo(std::ostringstream& oss) { | |||
| MS_LOG(INFO) << "Get graph analysis information begin"; | |||
| auto& stack = GetCNodeDebugStack(); | |||
| auto stack = GetCNodeDebugStack(); | |||
| if (stack.empty()) { | |||
| MS_LOG(INFO) << "Length of analysis information stack is empty."; | |||
| return; | |||
| } | |||
| OutputAnalysisGraphInfo(); | |||
| OutputAnalyzedGraphWithType(); | |||
| oss << "\nThe function call stack:\n"; | |||
| int index = 0; | |||
| @@ -106,6 +106,27 @@ void Primitive::set_signatures( | |||
| } | |||
| } | |||
| std::string Primitive::GetAttrsText() const { | |||
| if (attrs_.empty()) { | |||
| return ""; | |||
| } | |||
| std::ostringstream oss; | |||
| oss << "["; | |||
| bool is_first = true; | |||
| for (auto& attr : attrs_) { | |||
| if (is_first) { | |||
| is_first = false; | |||
| } else { | |||
| oss << ", "; | |||
| } | |||
| oss << attr.first << "=" << attr.second->DumpText(); | |||
| } | |||
| oss << "]"; | |||
| return oss.str(); | |||
| } | |||
| py::function PrimitivePy::GetBpropFunction() { | |||
| static const char* const get_bprop_func_name = "get_bprop"; | |||
| if (py::hasattr(python_obj_, get_bprop_func_name)) { | |||
| @@ -102,6 +102,7 @@ class Primitive : public Named { | |||
| PrimType prim_type() const { return prim_type_; } | |||
| std::string instance_name() const { return instance_name_; } | |||
| std::string GetAttrsText() const; | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const Primitive& other) const; | |||
| ~Primitive() override = default; | |||
| @@ -22,6 +22,7 @@ | |||
| #include "operator/ops.h" | |||
| #include "pipeline/static_analysis/prim.h" | |||
| #include "pipeline/static_analysis/abstract_function.h" | |||
| #include "debug/trace.h" | |||
| namespace mindspore { | |||
| using Shape = abstract::Shape; | |||
| @@ -124,6 +125,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, start_index}; | |||
| try { | |||
| trace::ClearTraceStack(); | |||
| engine_->Run(tupleSliceGraphPtr, args_spec_list); | |||
| FAIL() << "Excepted exception :Args type is wrong"; | |||
| } catch (std::runtime_error const &err) { | |||