You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel_flags.h 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
  17. #define MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/ms_context.h"
  24. namespace mindspore {
  25. namespace context {
  26. class GraphKernelFlags {
  27. public:
  28. static const GraphKernelFlags &GetInstance() {
  29. static std::unique_ptr<GraphKernelFlags> flags(nullptr);
  30. auto contexts = GetGraphKernelContext();
  31. if (flags == nullptr || contexts.first != flags->flags_cache_ || contexts.second != flags->enable_cache_) {
  32. flags.reset(new GraphKernelFlags(contexts.first, contexts.second));
  33. flags->Refresh();
  34. }
  35. return *flags;
  36. }
  37. // Dump all flags to json-format string
  38. std::string DumpAllFlags() const;
  39. // Check whether graph_kernel is enabled
  40. bool IsEnableGraphKernel() const { return opt_level > 0; }
  41. GraphKernelFlags(const GraphKernelFlags &flags) = delete;
  42. ~GraphKernelFlags() = default;
  43. public:
  44. /**
  45. * Dump info as human-readable text.
  46. * A directory "graph_kernel_dump" will be created, and all information will be dumped in this directory.
  47. */
  48. bool dump_as_text{false};
  49. /**
  50. * Enable stitch fusion in graph kernel fusion strategy.
  51. */
  52. bool enable_stitch_fusion{false};
  53. /**
  54. * Enable parallel fusion in graph kernel fusion strategy.
  55. */
  56. bool enable_parallel_fusion{false};
  57. /**
  58. * Optimization level, value from 0 to 3.
  59. * 0: GraphKernel disabled
  60. * 1: GraphKernel enabled
  61. * 2 and 3 are not supported now.
  62. * the default value is controlled by context `enable_graph_kernel`,
  63. * but if it's also set in `graph_kernel_flags`, then the flag will prevail.
  64. */
  65. unsigned int opt_level{0};
  66. /**
  67. * auto_tune, unsupported now.
  68. */
  69. unsigned int auto_tune{0};
  70. /**
  71. * cluster_limit, unsupported now.
  72. */
  73. unsigned int cluster_limit{30};
  74. /**
  75. * Additional expanding operators (case sensitive).
  76. * The operators to be added into the default expanding operator list.
  77. */
  78. std::vector<std::string> enable_expand_ops;
  79. /**
  80. * Expanding operators to be enabled (case sensitive).
  81. * Unlike the "enable_expand_ops", the default list will be overwritten by this list.
  82. * Note that the "enable_expand_ops" and "disable_expand_ops" will be ignored if this flag is set.
  83. */
  84. std::vector<std::string> enable_expand_ops_only;
  85. /**
  86. * Expanding operators to be disabled (case sensitive).
  87. * The behavior is undefined when this list overlaps with "enable_expand_ops".
  88. */
  89. std::vector<std::string> disable_expand_ops;
  90. /**
  91. * Additional clustering operators (case sensitive).
  92. * The operators to be added into the default clustering operator list.
  93. */
  94. std::vector<std::string> enable_cluster_ops;
  95. /**
  96. * Clustering operators to be enabled (case sensitive).
  97. * Unlike the "enable_cluster_ops", the default list will be overwritten by this list.
  98. * Note that the "enable_cluster_ops" and "disable_cluster_ops" will be ignored if this flag is set.
  99. */
  100. std::vector<std::string> enable_cluster_ops_only;
  101. /**
  102. * Clustering operators to be disabled (case sensitive).
  103. * The behavior is undefined when this list overlaps with "enable_cluster_ops".
  104. */
  105. std::vector<std::string> disable_cluster_ops;
  106. /**
  107. * enable_pass_only, unsupported now.
  108. */
  109. std::vector<std::string> enable_pass_only;
  110. /**
  111. * disable_pass, unsupported now.
  112. */
  113. std::vector<std::string> disable_pass;
  114. private:
  115. GraphKernelFlags(const std::string &graph_kernel_flags, bool enable_graph_kernel)
  116. : flags_cache_(graph_kernel_flags), enable_cache_(enable_graph_kernel) {
  117. opt_level = enable_graph_kernel ? 1 : 0;
  118. }
  119. // get the `graph_kernel_flags` and `enable_graph_kernel`
  120. static std::pair<std::string, bool> GetGraphKernelContext() {
  121. auto context = MsContext::GetInstance();
  122. MS_EXCEPTION_IF_NULL(context);
  123. // Use the environment variable in priority
  124. auto env_flags = std::getenv("MS_GRAPH_KERNEL_FLAGS");
  125. std::string flags = env_flags ? std::string(env_flags) : context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS);
  126. return std::make_pair(flags, context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL));
  127. }
  128. // parse and refresh the flags
  129. void Refresh();
  130. // register the flags defined above
  131. void RegisterFlags(std::map<std::string, std::string> *flag_map);
  132. // cache the flag string to check whether the flags is changed.
  133. std::string flags_cache_;
  134. // cache the enable_graph_kernel value to check whether the context is changed.
  135. bool enable_cache_;
  136. };
  137. } // namespace context
  138. } // namespace mindspore
  139. #endif // MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H