You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel_flags.h 7.9 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
  17. #define MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
  #include <cstdint>
  #include <map>
  #include <memory>
  #include <string>
  #include <utility>
  #include <vector>
  23. namespace mindspore::graphkernel {
  // Optimization levels for GraphKernelFlags::opt_level.
  // `inline` (C++17) gives every constant one program-wide definition instead of a separate
  // internal-linkage copy in each translation unit that includes this header.
  inline constexpr unsigned int OptLevel_0 = 0;    // Disabled
  inline constexpr unsigned int OptLevel_1 = 1;    // Basic functions
  inline constexpr unsigned int OptLevel_2 = 2;    // Default functions
  inline constexpr unsigned int OptLevel_3 = 3;    // Experimental functions
  inline constexpr unsigned int OptLevel_MAX = 4;  // One past the highest level listed above
  // Operator levels; OpLevel_0 is the default of GraphKernelFlags::parallel_ops_level and
  // GraphKernelFlags::fusion_ops_level.
  inline constexpr unsigned int OpLevel_0 = 0;
  inline constexpr unsigned int OpLevel_1 = 1;
  inline constexpr unsigned int OpLevel_MAX = 2;  // One past OpLevel_1
  32. class GraphKernelFlags {
  33. public:
  34. static const GraphKernelFlags &GetInstance() {
  35. static std::unique_ptr<GraphKernelFlags> flags(nullptr);
  36. auto contexts = GetGraphKernelContext();
  37. if (flags == nullptr || contexts.first != flags->flags_cache_ || contexts.second != flags->enable_graph_kernel_) {
  38. flags.reset(new GraphKernelFlags(contexts.first, contexts.second));
  39. flags->Refresh();
  40. }
  41. return *flags;
  42. }
  43. // Dump all flags to json-format string
  44. std::string DumpAllFlags() const;
  45. // Check whether graph_kernel is enabled
  46. bool IsEnableGraphKernel() const { return opt_level > OptLevel_0; }
  47. // Check whether GraphKernel supports current situation.
  48. void CheckSupport() const;
  49. GraphKernelFlags(const GraphKernelFlags &flags) = delete;
  50. GraphKernelFlags(GraphKernelFlags &&flags) = delete;
  51. void operator=(const GraphKernelFlags &flags) = delete;
  52. ~GraphKernelFlags() = default;
  53. public:
  54. /**
  55. * Dump info as human-readable text.
  56. * A directory "graph_kernel_dump" will be created, and all information will be dumped in this directory.
  57. */
  58. bool dump_as_text{false};
  59. /**
  60. * Enable stitch fusion in graph kernel fusion strategy.
  61. *
  62. * Experimental feature, enabled by default when opt_level=3
  63. */
  64. bool enable_stitch_fusion{false};
  65. /**
  66. * Enable recompute fusion in graph kernel fusion strategy, enabled when op_level>=2.
  67. */
  68. bool enable_recompute_fusion{false};
  69. /**
  70. * Enable parallel fusion in graph kernel fusion strategy.
  71. *
  72. * Experimental feature, enabled by default when opt_level=3
  73. */
  74. bool enable_parallel_fusion{false};
  75. /**
  76. * Parallel AKG's operators by level.
  77. * 0: Parallel operators by local data relation analyzation with less memory influence.
  78. * 1: Parallel operators with global analyzation with more memory influence.
  79. */
  80. unsigned int parallel_ops_level{OpLevel_0};
  81. /**
  82. * Enable low precision in data transferring between graph kernel and computing in graph kernel
  83. * in graph kernel.
  84. * Experimental feature, enabled by the enable_low_precision flag
  85. */
  86. bool enable_low_precision{false};
  87. /**
  88. * Expand and cluster AKG's operators by level.
  89. */
  90. unsigned int fusion_ops_level{OpLevel_0};
  91. /**
  92. * Enable optimization for transform operators (Transpose/TransData)
  93. *
  94. * Experimental feature, enabled by default when opt_level=3.
  95. */
  96. bool enable_trans_op_optimize{false};
  97. /**
  98. * Optimization level, value from 0 to 3.
  99. * 0: Disable GraphKernel
  100. * 1: Enable GraphKernel with basic features only.
  101. * 2: Enable GraphKernel with all stable features.
  102. * 3: Enable GraphKernel with all experimental features.
  103. * The default value is OptLevel_2 when the context "enable_graph_kernel" is set,
  104. * but if it's also changed in "graph_kernel_flags", then the "graph_kernel_flags" will prevail.
  105. */
  106. unsigned int opt_level{0}; // defaults 0 or 2
  107. /**
  108. * Online tuning level, value from 0 to 3.
  109. * 0: Disable online tuning
  110. * 1-3: The higher level, the larger tuning space, and the more time it takes.
  111. */
  112. unsigned int online_tuning{0};
  113. /**
  114. * Threshold for detection of recopute's memory increment case, unit is byte.
  115. */
  116. int64_t recompute_increment_threshold{0};
  117. /**
  118. * Threshold for detection of recopute's memory peak case, unit is byte.
  119. */
  120. int64_t recompute_peak_threshold{0};
  121. /**
  122. * AKG's operator repository file path.
  123. */
  124. std::string repository_path;
  125. /**
  126. * Additional expanding operators (case sensitive).
  127. * The operators to be added into the default expanding operator list.
  128. */
  129. std::vector<std::string> enable_expand_ops;
  130. /**
  131. * Expanding operators to be enabled (case sensitive).
  132. * Unlike the "enable_expand_ops", the default list will be overwritten by this list.
  133. * Note that the "enable_expand_ops" and "disable_expand_ops" will be ignored if this flag is set.
  134. */
  135. std::vector<std::string> enable_expand_ops_only;
  136. /**
  137. * Expanding operators to be disabled (case sensitive).
  138. * The behavior is undefined when this list overlaps with "enable_expand_ops".
  139. */
  140. std::vector<std::string> disable_expand_ops;
  141. /**
  142. * Additional clustering operators (case sensitive).
  143. * The operators to be added into the default clustering operator list.
  144. */
  145. std::vector<std::string> enable_cluster_ops;
  146. /**
  147. * Clustering operators to be enabled (case sensitive).
  148. * Unlike the "enable_cluster_ops", the default list will be overwritten by this list.
  149. * Note that the "enable_cluster_ops" and "disable_cluster_ops" will be ignored if this flag is set.
  150. */
  151. std::vector<std::string> enable_cluster_ops_only;
  152. /**
  153. * Clustering operators to be disabled (case sensitive).
  154. * The behavior is undefined when this list overlaps with "enable_cluster_ops".
  155. */
  156. std::vector<std::string> disable_cluster_ops;
  157. /**
  158. * Arithmetic simplify expressions to be enabled (case sensitive).
  159. * The default list will be overwritten by this list.
  160. * Note that "disable_simplify_exprs" will be ignored if this flag is set.
  161. */
  162. std::vector<std::string> enable_simplify_exprs_only;
  163. /**
  164. * Arithmetic simplify expressions to be disabled (case sensitive).
  165. */
  166. std::vector<std::string> disable_simplify_exprs;
  167. /**
  168. * Passes to be enabled.
  169. * By default, the passes is controlled by "opt_level" and target device,
  170. * user can manually enable some passes by setting this flag.
  171. * The format is "stage_id.pass_id" or "stage_name.pass_name", which corresponds to the ir filename.
  172. */
  173. std::vector<std::string> enable_pass;
  174. /**
  175. * Passes to be disabled.
  176. * By default, the passes is controlled by "opt_level" and target device,
  177. * user can manually disable some passes by setting this flag.
  178. * The format is "stage_id.pass_id" or "stage_name.pass_name", which corresponds to the ir filename.
  179. */
  180. std::vector<std::string> disable_pass;
  181. private:
  182. GraphKernelFlags(const std::string &graph_kernel_flags, bool enable_graph_kernel)
  183. : flags_cache_(graph_kernel_flags), enable_graph_kernel_(enable_graph_kernel) {}
  184. // get the `graph_kernel_flags` and `enable_graph_kernel`
  185. static std::pair<std::string, bool> GetGraphKernelContext();
  186. // parse and refresh the flags
  187. void Refresh();
  188. // register the flags defined above
  189. void RegisterFlags(std::map<std::string, std::string> *flag_map);
  190. // cache the flag string to check whether the flags is changed.
  191. std::string flags_cache_;
  192. // cache the enable_graph_kernel value to check whether the context is changed.
  193. bool enable_graph_kernel_;
  194. };
  195. } // namespace mindspore::graphkernel
  196. #endif // MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H