Browse Source

Support fused node corresponding code print

tags/v1.6.0
huanghui 4 years ago
parent
commit
767caad833
10 changed files with 241 additions and 60 deletions
  1. +42
    -28
      mindspore/ccsrc/debug/anf_ir_dump.cc
  2. +12
    -3
      mindspore/ccsrc/frontend/optimizer/cse.cc
  3. +1
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h
  4. +49
    -0
      mindspore/core/ir/anf.cc
  5. +25
    -13
      mindspore/core/ir/anf.h
  6. +5
    -3
      mindspore/core/ir/primal_debug_info.h
  7. +41
    -1
      mindspore/core/utils/info.cc
  8. +7
    -2
      mindspore/core/utils/info.h
  9. +57
    -10
      mindspore/core/utils/trace_base.cc
  10. +2
    -0
      mindspore/core/utils/trace_base.h

+ 42
- 28
mindspore/ccsrc/debug/anf_ir_dump.cc View File

@@ -410,6 +410,45 @@ void DumpShape(const AnfNodePtr &nd, const FuncGraphPtr &sub_graph, const std::s
gsub->buffer << std::endl;
}

void DumpPrimalDebugInfos(const CNodePtr &nd, const std::shared_ptr<SubGraphIRInfo> &gsub) {
MS_EXCEPTION_IF_NULL(nd);
auto primal_debug_infos = nd->primal_debug_infos();
if (!primal_debug_infos.empty()) {
gsub->buffer << " # Corresponding forward node candidate:\n";
for (auto &primal_debug_info : primal_debug_infos) {
gsub->buffer << trace::GetDebugInfo(primal_debug_info, " # ", kSourceLineTipDiscard) << "\n";
}
}
}

void DumpDebugInfo(const CNodePtr &nd, const std::shared_ptr<SubGraphIRInfo> &gsub, const LocDumpMode &dump_location) {
MS_EXCEPTION_IF_NULL(nd);
if (dump_location == kTopStack) {
auto fused_debug_infos = nd->fused_debug_infos();
if (!fused_debug_infos.empty()) {
gsub->buffer << " # Corresponding code candidate:\n";
for (const auto &debug_info : fused_debug_infos) {
auto debug_info_str = trace::GetDebugInfo(debug_info, " # ", kSourceLineTipDiscard);
if (!debug_info_str.empty()) {
gsub->buffer << debug_info_str << "\n";
}
}
} else {
auto debug_info_str = trace::GetDebugInfo(nd->debug_info(), " # ", kSourceLineTipDiscard);
if (!debug_info_str.empty()) {
gsub->buffer << debug_info_str << "\n";
}
}

DumpPrimalDebugInfos(nd, gsub);
} else if (dump_location == kWholeStack) {
auto traces = mindspore::trace::GetSourceLineList(nd);
for (auto &trace : traces) {
gsub->buffer << " # " << trace;
}
}
}

void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap<AnfNodePtr, int32_t> *const para_map,
const std::shared_ptr<SubGraphIRInfo> &gsub, bool dump_full_name = false,
LocDumpMode dump_location = kOff) {
@@ -457,34 +496,9 @@ void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap<Anf
if (dump_full_name) {
gsub->buffer << " : (" << nd->fullname_with_scope() << ")" << std::endl;
}
if (dump_location == kTopStack) {
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
gsub->buffer << trace::GetDebugInfo(nd->debug_info(), " # ", kSourceLineTipDiscard) << "#"
<< label_manage::Label(nd->debug_info()) << "\n";
auto primal_debug_infos = nd->primal_debug_infos();
if (!primal_debug_infos.empty()) {
gsub->buffer << " # Corresponding forward node candidate:\n";
for (auto &primal_debug_info : primal_debug_infos) {
gsub->buffer << trace::GetDebugInfo(primal_debug_info, " # ", kSourceLineTipDiscard) << "#"
<< label_manage::Label(primal_debug_info) << "\n";
}
}
} else {
gsub->buffer << trace::GetDebugInfo(nd->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
auto primal_debug_infos = nd->primal_debug_infos();
if (!primal_debug_infos.empty()) {
gsub->buffer << " # Corresponding forward node candidate:\n";
for (auto &primal_debug_info : primal_debug_infos) {
gsub->buffer << trace::GetDebugInfo(primal_debug_info, " # ", kSourceLineTipDiscard) << "\n";
}
}
}
} else if (dump_location == kWholeStack) {
auto traces = mindspore::trace::GetSourceLineList(nd);
for (auto &trace : traces) {
gsub->buffer << " # " << trace;
}
}

// print debug info
DumpDebugInfo(nd, gsub, dump_location);
}

void DumpIRInSubgraph(const std::vector<AnfNodePtr> &nodes, OrderedMap<AnfNodePtr, int32_t> *para_map,


+ 12
- 3
mindspore/ccsrc/frontend/optimizer/cse.cc View File

@@ -48,6 +48,17 @@ bool IsSetRecomputed(const CNodePtr &a, const CNodePtr &b) {
(WithRecomputedScope(b) && !b->HasAttr(kAttrNeedCseAfterRecompute));
}

void UpdateDebugInfoAndDumpFlag(const AnfNodePtr &main, const AnfNodePtr &node) {
if (main == nullptr || !main->isa<CNode>()) {
return;
}
if (AnfUtils::GetDumpFlag(node) && !AnfUtils::GetDumpFlag(main)) {
AnfUtils::SetDumpFlag(main);
}
auto main_cnode = main->cast<CNodePtr>();
main_cnode->AddFusedDebugInfo(node);
}

BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
MS_EXCEPTION_IF_NULL(node);
auto node_abs = node->abstract();
@@ -247,9 +258,7 @@ bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::si
}
if (CheckReplace(node, main)) {
changes = true;
if (AnfUtils::GetDumpFlag(node) && !AnfUtils::GetDumpFlag(main)) {
AnfUtils::SetDumpFlag(main);
}
UpdateDebugInfoAndDumpFlag(main, node);
(void)manager->Replace(node, main);
(void)clear_set.insert(i);
}


+ 1
- 0
mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h View File

@@ -58,6 +58,7 @@ class MergeAddN : public AnfVisitor {

auto new_node = fg->NewCNode({addn, make_node});
UpdateDumpFlag(new_node);
new_node->AddFusedDebugInfoList(addn_nodes_);
return new_node;
}



+ 49
- 0
mindspore/core/ir/anf.cc View File

@@ -125,6 +125,55 @@ std::string CNode::DebugString(int recursive_level) const {
return buffer.str();
}

void CNode::AddFusedDebugInfo(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
return;
}
if (shared_from_this() == node) {
this->AddFusedDebugInfo(node->debug_info());
return;
}
auto cnode = node->cast<CNodePtr>();
auto node_fused_debug_infos = cnode->fused_debug_infos();
if (!node_fused_debug_infos.empty()) {
std::for_each(node_fused_debug_infos.begin(), node_fused_debug_infos.end(),
[this](const NodeDebugInfoPtr &debug_info) { this->AddFusedDebugInfo(debug_info); });
} else {
this->AddFusedDebugInfo(cnode->debug_info());
}

auto primal_debug_infos = cnode->primal_debug_infos();
if (!primal_debug_infos.empty()) {
std::for_each(primal_debug_infos.begin(), primal_debug_infos.end(),
[this](const NodeDebugInfoPtr &debug_info) { this->AddPrimalDebugInfo(debug_info); });
}
}

void CNode::AddFusedDebugInfoList(const std::vector<AnfNodePtr> &nodes) {
std::for_each(nodes.begin(), nodes.end(), [this](const AnfNodePtr &node) { this->AddFusedDebugInfo(node); });
}

void CNode::AddFusedDebugInfo(const NodeDebugInfoPtr &debug_info) {
if (debug_info == nullptr) {
return;
}
(void)fused_debug_infos_.emplace(debug_info);
}

void CNode::AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos) {
std::for_each(debug_infos.begin(), debug_infos.end(),
[this](const NodeDebugInfoPtr &debug_info) { this->AddFusedDebugInfo(debug_info); });
}

NodeDebugInfoSet CNode::primal_debug_infos() const { return primal_debug_infos_; }

void CNode::set_primal_debug_infos(const NodeDebugInfoSet &debug_infos) {
std::for_each(debug_infos.begin(), debug_infos.end(),
[this](const NodeDebugInfoPtr &debug_info) { this->AddPrimalDebugInfo(debug_info); });
}

void CNode::AddPrimalDebugInfo(const NodeDebugInfoPtr &debug_info) { (void)primal_debug_infos_.emplace(debug_info); }

std::string Parameter::DebugString(int recursive_level) const {
std::ostringstream buffer;
if (recursive_level > 0) {


+ 25
- 13
mindspore/core/ir/anf.h View File

@@ -56,6 +56,7 @@ class AbstractBase;
using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
using NodeDebugInfoSet = std::set<NodeDebugInfoPtr, DebugInfoCompare>;

class Value;
using ValuePtr = std::shared_ptr<Value>;
@@ -616,14 +617,17 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
/// \brief Get primal debug information.
///
/// \return The primal debug information.
std::vector<NodeDebugInfoPtr> primal_debug_infos() { return primal_debug_infos_; }
NodeDebugInfoSet primal_debug_infos() const;

/// \brief Set primal debug information.
///
/// \param[in] debug_infos Debug information of this CNode.
void set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> &debug_infos) {
primal_debug_infos_.insert(primal_debug_infos_.end(), debug_infos.begin(), debug_infos.end());
}
void set_primal_debug_infos(const NodeDebugInfoSet &debug_infos);

/// \brief Add a primal debug information.
///
/// \param[in] debug_info A debug information.
void AddPrimalDebugInfo(const NodeDebugInfoPtr &debug_info);

void CloneCNodeInfo(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
@@ -657,24 +661,32 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
/// \brief Get the debug infos of fused nodes.
///
/// \return A vector of debug infos.
std::vector<NodeDebugInfoPtr> fused_debug_infos() const { return fused_debug_infos_; }
NodeDebugInfoSet fused_debug_infos() const { return fused_debug_infos_; }

/// \brief Set the debug infos for CNode.
///
/// \param fused_debug_infos The debuf infos to be set.
void set_fused_debug_infos(const std::vector<NodeDebugInfoPtr> &fused_debug_infos) {
fused_debug_infos_ = fused_debug_infos;
}
/// \param fused_debug_infos The debug infos to be set.
void set_fused_debug_infos(const NodeDebugInfoSet &fused_debug_infos) { fused_debug_infos_ = fused_debug_infos; }

/// \brief Add a node's debug info or fused debug info.
///
/// \param node An anf node.
void AddFusedDebugInfo(const AnfNodePtr &node) {}
void AddFusedDebugInfo(const AnfNodePtr &node);

/// \brief Add a vector of nodes' debug info or fused debug info.
///
/// \param nodes A vector of anf nodes.
void AddFusedDebugInfoList(const std::vector<AnfNodePtr> &nodes) {}
void AddFusedDebugInfoList(const std::vector<AnfNodePtr> &nodes);

/// \brief Add a node debug info.
///
/// \param node A node debug info of an anf node.
void AddFusedDebugInfo(const NodeDebugInfoPtr &debug_info);

/// \brief Add a list of node debug infos.
///
/// \param node A node debug info of an anf node.
void AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos);

private:
std::vector<AnfNodePtr> inputs_;
@@ -689,8 +701,8 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
std::pair<ValueNodePtr, std::string> output_value_;
mindspore::HashMap<std::string, ValuePtr> attrs_;
mindspore::HashMap<std::string, ValuePtr> primal_attrs_;
std::vector<NodeDebugInfoPtr> primal_debug_infos_;
std::vector<NodeDebugInfoPtr> fused_debug_infos_;
NodeDebugInfoSet primal_debug_infos_;
NodeDebugInfoSet fused_debug_infos_;
ssize_t input_tensor_num_ = -1;
};



+ 5
- 3
mindspore/core/ir/primal_debug_info.h View File

@@ -20,6 +20,7 @@
#include <memory>
#include <stack>
#include <vector>
#include <set>
#include "utils/hash_map.h"
#include "utils/info.h"

@@ -34,14 +35,15 @@ class PrimalDebugInfoManager {
PrimalDebugInfoManager &operator=(const PrimalDebugInfoManager &) = delete;
~PrimalDebugInfoManager() = default;
void SetPrimalDebugInfo(const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
primal_debug_infos_ = primal_debug_infos;
std::for_each(primal_debug_infos.begin(), primal_debug_infos.end(),
[this](const NodeDebugInfoPtr &debug_info) { primal_debug_infos_.emplace(debug_info); });
}
void ClearPrimalDebugInfo() { primal_debug_infos_.clear(); }
std::vector<NodeDebugInfoPtr> GetCurrentPrimalDebugInfo() { return primal_debug_infos_; }
std::set<NodeDebugInfoPtr, DebugInfoCompare> GetCurrentPrimalDebugInfo() { return primal_debug_infos_; }

private:
PrimalDebugInfoManager() = default;
std::vector<NodeDebugInfoPtr> primal_debug_infos_;
std::set<NodeDebugInfoPtr, DebugInfoCompare> primal_debug_infos_;
};

// PrimalDebugInfoGuard is a class that help generate the back propagation cnode


+ 41
- 1
mindspore/core/utils/info.cc View File

@@ -81,6 +81,14 @@ std::string Location::ToString(SourceLineTip tip) const {
return debug_info_ss.str();
}

bool Location::operator<(const Location &other) const {
auto ret = file_name_.compare(other.file_name());
if (ret != 0) {
return ret < 0;
}
return line_ < other.line();
}

int64_t DebugInfo::get_id() const {
// cppcheck-suppress variableScope
static int64_t current_id = 1;
@@ -121,7 +129,7 @@ std::string GraphDebugInfo::debug_name() {
return name_;
}

LocationPtr GraphDebugInfo::location() {
LocationPtr GraphDebugInfo::location() const {
// Function may have decorator which is included in its location.
auto loc = DebugInfo::location();
if (deco_loc_ != nullptr && loc != nullptr) {
@@ -167,4 +175,36 @@ void TraceManager::ClearParseOrResolveDebugInfo() { TraceManager::parse_or_resol
thread_local std::vector<TraceContext> TraceManager::trace_context_stack_;

thread_local DebugInfoPtr TraceManager::parse_or_resolve_debug_info_ = nullptr;

LocationPtr GetFirstLocation(const DebugInfoPtr &debug_info) {
auto tmp = debug_info;
while (tmp != nullptr) {
if (tmp->location() != nullptr) {
return tmp->location();
}
if (tmp->trace_info() != nullptr) {
tmp = tmp->trace_info()->debug_info();
} else {
break;
}
}
return nullptr;
}

bool DebugInfoCompare::operator()(const DebugInfoPtr &left, const DebugInfoPtr &right) {
MS_EXCEPTION_IF_NULL(left);
MS_EXCEPTION_IF_NULL(right);
if (left == right) {
return false;
}
auto left_loc = GetFirstLocation(left);
auto right_loc = GetFirstLocation(right);
if (left_loc == nullptr || right_loc == nullptr) {
return left < right;
}
if (left_loc == right_loc) {
return false;
}
return *left_loc < *right_loc;
}
} // namespace mindspore

+ 7
- 2
mindspore/core/utils/info.h View File

@@ -42,6 +42,8 @@ class Location {
int column() const { return column_; }
int column_end() const { return column_end_; }

bool operator<(const Location &other) const;

private:
std::string file_name_;
int line_;
@@ -210,7 +212,7 @@ class MS_CORE_API DebugInfo {
/// \brief Get the location.
///
/// \return The location.
virtual LocationPtr location() { return location_; }
virtual LocationPtr location() const { return location_; }

/// \brief Get the name.
///
@@ -321,7 +323,7 @@ class GraphDebugInfo : public DebugInfo {
~GraphDebugInfo() override = default;

std::string debug_name() override;
LocationPtr location() override;
LocationPtr location() const override;
LocationPtr deco_location() { return deco_loc_; }
void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
FuncGraphPtr get_graph() const { return func_graph_.lock(); }
@@ -371,6 +373,9 @@ inline TraceContext::TraceContext(const LocationPtr &loc, const std::string &fun
}
}

struct DebugInfoCompare {
bool operator()(const DebugInfoPtr &left, const DebugInfoPtr &right);
};
} // namespace mindspore

#endif // MINDSPORE_CORE_UTILS_INFO_H_

+ 57
- 10
mindspore/core/utils/trace_base.cc View File

@@ -123,16 +123,19 @@ std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
}

std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) {
std::ostringstream oss;
if (info == nullptr) {
return "";
}

auto debug_info = GetTracedDebugInfo(info, tip);
if (debug_info.empty()) {
return "";
}
if (tip == kSourceLineTipDiscard) {
std::replace(debug_info.begin(), debug_info.end(), '\r', '/');
std::replace(debug_info.begin(), debug_info.end(), '\n', '/');
}
std::ostringstream oss;
oss << prefix << debug_info;
return oss.str();
}
@@ -159,8 +162,12 @@ std::string DumpSourceLines(AnfNode *node) {
return DumpSourceLines(ptr);
}

void GetSourceLineFromDebugInfo(const DebugInfoPtr &debug_info, std::vector<std::string> *result) {
void GetSourceLineFromDebugInfo(const DebugInfoPtr &debug_info, std::vector<std::string> *result,
const std::string &prefix = "") {
MS_EXCEPTION_IF_NULL(result);
auto info_vec = GetSourceCodeDebugInfoVec(debug_info);
const std::string spaces(prefix.size(), ' ');
bool first_line = true;
for (const auto &info : info_vec) {
MS_EXCEPTION_IF_NULL(info);
auto loc = info->location();
@@ -169,7 +176,47 @@ void GetSourceLineFromDebugInfo(const DebugInfoPtr &debug_info, std::vector<std:
}
auto loc_str = loc->ToString(kSourceLineTipDiscard);
ReplaceLinefeed(&loc_str);
result->push_back(loc_str + "\n");
if (first_line) {
result->push_back(prefix + loc_str + "\n");
first_line = false;
} else {
result->push_back(spaces + loc_str + "\n");
}
}
}

void GetFusedDebugInfos(const NodeDebugInfoSet &fused_debug_infos, std::vector<std::string> *result) {
MS_EXCEPTION_IF_NULL(result);
result->push_back("Corresponding code candidate:\n");
// Flag to mark whether fused_debug_infos has valid print.
bool is_empty = true;
for (const auto &debug_info : fused_debug_infos) {
std::vector<std::string> debug_info_vec_str;
GetSourceLineFromDebugInfo(debug_info, &debug_info_vec_str, kSectionPrefix);
if (!debug_info_vec_str.empty()) {
result->insert(result->end(), debug_info_vec_str.begin(), debug_info_vec_str.end());
is_empty = false;
}
}

if (is_empty) {
result->pop_back();
}
}

void GetPrimalDebugInfos(const CNodePtr &cnode, std::vector<std::string> *result) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(result);
auto primal_debug_infos = cnode->primal_debug_infos();
if (!primal_debug_infos.empty()) {
result->emplace_back("Corresponding forward node candidate:\n");
for (const auto &primal_debug_info : primal_debug_infos) {
std::vector<std::string> debug_info_vec_str;
GetSourceLineFromDebugInfo(primal_debug_info, &debug_info_vec_str, kSectionPrefix);
if (!debug_info_vec_str.empty()) {
result->insert(result->end(), debug_info_vec_str.begin(), debug_info_vec_str.end());
}
}
}
}

@@ -179,18 +226,18 @@ std::vector<std::string> GetSourceLineList(const AnfNodePtr &node) {
MS_LOG(WARNING) << "Node is null";
return result;
}
GetSourceLineFromDebugInfo(node->debug_info(), &result);
if (!node->isa<CNode>()) {
GetSourceLineFromDebugInfo(node->debug_info(), &result);
return result;
}
auto cnode = node->cast<CNodePtr>();
auto primal_debug_infos = cnode->primal_debug_infos();
if (!primal_debug_infos.empty()) {
result.emplace_back("Corresponding forward node candidate:\n");
for (auto &primal_debug_info : primal_debug_infos) {
GetSourceLineFromDebugInfo(primal_debug_info, &result);
}
auto fused_debug_infos = cnode->fused_debug_infos();
if (fused_debug_infos.empty()) {
GetSourceLineFromDebugInfo(node->debug_info(), &result);
} else {
GetFusedDebugInfos(fused_debug_infos, &result);
}
GetPrimalDebugInfos(cnode, &result);
return result;
}



+ 2
- 0
mindspore/core/utils/trace_base.h View File

@@ -30,6 +30,8 @@

namespace mindspore {
namespace trace {
constexpr auto kSectionPrefix = " - ";

std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
SourceLineTip tip = kSourceLineTipNextLine);


Loading…
Cancel
Save