You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

info.h 8.4 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CORE_UTILS_INFO_H_
  17. #define MINDSPORE_CORE_UTILS_INFO_H_
  #include <atomic>
  #include <iostream>
  #include <memory>
  #include <stack>
  #include <string>
  #include <utility>
  #include <vector>

  #include "base/base.h"
  #include "utils/trace_info.h"
  26. namespace mindspore {
  // Selects how Location::ToString renders the source-line tip; kSourceLineTipNextLine is the default
  // argument of ToString. NOTE(review): the names suggest "omit", "on the following line" and "inline"
  // rendering — confirm against the out-of-line ToString implementation.
  enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 };

  // Location records a span in a source file: the file name plus start line/column and end line/column.
  //
  // The class owns no raw resources, so copy, move and destruction are left to the compiler-generated
  // defaults (Rule of Zero). Declaring any of them by hand (even "= default" for the destructor) would
  // suppress the implicit noexcept move operations and turn every move into a deep copy of file_name_.
  class Location {
   public:
    // file_name: source file path/name; line/column: start position; line_end/column_end: end position.
    Location(const std::string &file_name, int line, int column, int line_end, int column_end)
        : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {}

    // Renders this location as text; defined out of line.
    std::string ToString(SourceLineTip tip = kSourceLineTipNextLine) const;

    std::string file_name() const { return file_name_; }

    int line() const { return line_; }
    void set_line(int line) { line_ = line; }

    int line_end() const { return line_end_; }
    void set_line_end(int line) { line_end_ = line; }

    int column() const { return column_; }
    void set_column(int column) { column_ = column; }

    int column_end() const { return column_end_; }
    void set_column_end(int column) { column_end_ = column; }

   private:
    std::string file_name_;
    int line_;
    int column_;
    int line_end_;
    int column_end_;
  };
  58. class TraceContext;
  59. using TraceContextPtr = std::shared_ptr<TraceContext>;
  60. class TraceManager {
  61. public:
  62. TraceManager() = default;
  63. ~TraceManager() = default;
  64. static TraceContextPtr CurrentContextInfo();
  65. static void DebugTrace(const std::string &func_name, const LocationPtr &location);
  66. static void DebugTrace(const LocationPtr &location);
  67. static void DebugTrace(const TraceInfoPtr &trace_info);
  68. // debug trace with a cloned trace info with debug_info
  69. static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info);
  70. static void EndTrace();
  71. static void ClearParseOrResolveDebugInfo();
  72. static DebugInfoPtr GetParseOrResolveDebugInfo();
  73. thread_local static std::stack<TraceContextPtr> trace_context_stack_;
  74. thread_local static DebugInfoPtr parse_or_resolve_debug_info_;
  75. };
  76. class TraceGuard {
  77. public:
  78. TraceGuard(const std::string func_name, const LocationPtr &location) {
  79. TraceManager::DebugTrace(func_name, location);
  80. }
  81. explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); }
  82. explicit TraceGuard(const TraceInfoPtr &trace_info) { TraceManager::DebugTrace(trace_info); }
  83. TraceGuard(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) {
  84. TraceManager::DebugTrace(debug_info, trace_info);
  85. }
  86. ~TraceGuard() { TraceManager::EndTrace(); }
  87. };
  88. class TraceContext {
  89. public:
  90. ~TraceContext() = default;
  91. explicit TraceContext(const LocationPtr &loc) {
  92. ProcessAttributeFromContext();
  93. location_ = loc;
  94. }
  95. explicit TraceContext(const std::string &func_name) {
  96. ProcessAttributeFromContext();
  97. func_name_ = func_name;
  98. }
  99. explicit TraceContext(const TraceInfoPtr &trace_info) {
  100. ProcessAttributeFromContext();
  101. trace_info_ = trace_info;
  102. }
  103. void set_location(const LocationPtr &loc) { location_ = loc; }
  104. LocationPtr location() { return location_; }
  105. void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
  106. TraceInfoPtr trace_info() const { return trace_info_; }
  107. void set_func_name(const std::string &func_name) { func_name_ = func_name; }
  108. std::string func_name() { return func_name_; }
  109. protected:
  110. void ProcessAttributeFromContext();
  111. private:
  112. LocationPtr location_;
  113. TraceInfoPtr trace_info_;
  114. std::string func_name_;
  115. };
  116. class DebugInfo : public Base {
  117. public:
  118. DebugInfo();
  119. explicit DebugInfo(const std::string &name);
  120. explicit DebugInfo(const LocationPtr &loc);
  121. ~DebugInfo() override = default;
  122. MS_DECLARE_PARENT(DebugInfo, Base);
  123. int64_t debug_id();
  124. int64_t unique_id() const { return unique_id_; }
  125. int64_t unique_id_through_copy() const;
  126. std::string get_id() { return std::to_string(debug_id()); }
  127. void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
  128. TraceInfoPtr trace_info() const { return trace_info_; }
  129. void set_location(const LocationPtr &loc) { location_ = loc; }
  130. virtual LocationPtr location() { return location_; }
  131. std::string name() { return name_; }
  132. void set_name(const std::string &name) { name_ = name; }
  133. virtual std::string debug_name();
  134. virtual std::string get_python_func_belonged() { return ""; }
  135. protected:
  136. template <typename Derived>
  137. std::shared_ptr<Derived> shared_from_base() {
  138. return std::static_pointer_cast<Derived>(shared_from_this());
  139. }
  140. private:
  141. void InitValueFromContext() {
  142. if (TraceManager::CurrentContextInfo() != nullptr) {
  143. auto context_info = TraceManager::CurrentContextInfo();
  144. trace_info_ = context_info->trace_info();
  145. location_ = context_info->location();
  146. }
  147. }
  148. static int64_t gen_unique_id() {
  149. static int64_t cur_unique_id = 0;
  150. return cur_unique_id++;
  151. }
  152. protected:
  153. int64_t unique_id_;
  154. int64_t debug_id_;
  155. TraceInfoPtr trace_info_;
  156. LocationPtr location_;
  157. std::string name_;
  158. };
  159. class NodeDebugInfo : public DebugInfo {
  160. public:
  161. NodeDebugInfo() {
  162. if (TraceManager::CurrentContextInfo() != nullptr) {
  163. auto context_info = TraceManager::CurrentContextInfo();
  164. py_func_belonged_ = context_info->func_name();
  165. }
  166. }
  167. explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) {
  168. if (TraceManager::CurrentContextInfo() != nullptr) {
  169. auto context_info = TraceManager::CurrentContextInfo();
  170. py_func_belonged_ = context_info->func_name();
  171. }
  172. }
  173. ~NodeDebugInfo() override = default;
  174. std::string debug_name() override;
  175. void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); }
  176. std::shared_ptr<AnfNode> get_node() const { return node_.lock(); }
  177. void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; }
  178. std::string get_python_func_belonged() override { return py_func_belonged_; }
  179. private:
  180. AnfNodeWeakPtr node_;
  181. std::string py_func_belonged_;
  182. };
  183. using NodeDebugInfoPtr = std::shared_ptr<NodeDebugInfo>;
  184. class GraphDebugInfo : public DebugInfo {
  185. public:
  186. GraphDebugInfo() {
  187. if (TraceManager::CurrentContextInfo() != nullptr) {
  188. auto context_info = TraceManager::CurrentContextInfo();
  189. py_func_name_ = context_info->func_name();
  190. deco_loc_ = nullptr;
  191. }
  192. }
  193. explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) {
  194. if (TraceManager::CurrentContextInfo() != nullptr) {
  195. auto context_info = TraceManager::CurrentContextInfo();
  196. py_func_name_ = context_info->func_name();
  197. deco_loc_ = nullptr;
  198. }
  199. }
  200. ~GraphDebugInfo() override = default;
  201. std::string debug_name() override;
  202. LocationPtr location() override;
  203. LocationPtr deco_location() { return deco_loc_; }
  204. void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
  205. FuncGraphPtr get_graph() const { return func_graph_.lock(); }
  206. void set_full_name(const std::string &name) { full_name_ = name; }
  207. std::string get_full_name() { return full_name_; }
  208. void set_deco_location(const LocationPtr &deco_list_loc);
  209. std::string get_python_func_belonged() override { return py_func_name_; }
  210. private:
  211. FuncGraphWeakPtr func_graph_;
  212. LocationPtr deco_loc_;
  213. std::string py_func_name_;
  214. std::string full_name_;
  215. };
  216. using GraphDebugInfoPtr = std::shared_ptr<GraphDebugInfo>;
  217. } // namespace mindspore
  218. #endif // MINDSPORE_CORE_UTILS_INFO_H_