You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

segment_runner.cc 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/segment_runner.h"
  19. #include <algorithm>
  20. #include <functional>
  21. #include <memory>
  22. #include <set>
  23. #include <tuple>
  24. #include <unordered_map>
  25. #include <utility>
  26. #include <string>
  27. #include "utils/log_adapter.h"
  28. #include "utils/utils.h"
  29. #include "ir/manager.h"
  30. #include "ir/func_graph_cloner.h"
  31. #include "frontend/operator/ops.h"
  32. namespace mindspore {
  33. const char kMsConvert[] = "ms";
  34. const char kMsVm[] = "vm";
  35. const char kGeVm[] = "ge";
  36. namespace compile {
  37. // cached conversion
  38. ConvertCache g_ConvertCache;
  39. void ClearConvertCache() { g_ConvertCache.clear(); }
  40. // Return the list of nodes whose values are required beyond this segment.
  41. // Arguments:
  42. // lst: list of nodes (the segment)
  43. // users: dict mapping each node to its users (globally)
  44. // seen: set of nodes that are part of the segment
  45. AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector<AnfNodePtr> &seen) {
  46. AnfNodePtrList output;
  47. if (users.size() == 0) {
  48. return output;
  49. }
  50. (void)std::transform(
  51. std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr {
  52. auto usersn = users.find(n);
  53. bool is_referred_out_of_segment = std::any_of(
  54. std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair<AnfNodePtr, int> &u) -> bool {
  55. return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen);
  56. });
  57. if (n->isa<CNode>() && is_referred_out_of_segment) {
  58. return n;
  59. }
  60. return nullptr;
  61. });
  62. // remove nullptr
  63. for (auto it = output.begin(); it != output.end();) {
  64. if (*it == nullptr) {
  65. it = output.erase(it);
  66. } else {
  67. ++it;
  68. }
  69. }
  70. return output;
  71. }
  72. namespace {
  73. AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr,
  74. AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
  75. MS_EXCEPTION_IF_NULL(fg);
  76. MS_EXCEPTION_IF_NULL(inputs_ptr);
  77. MS_EXCEPTION_IF_NULL(eqv_ptr);
  78. MS_EXCEPTION_IF_NULL(node);
  79. auto &inputs = *inputs_ptr;
  80. auto &eqv = *eqv_ptr;
  81. if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
  82. eqv[node] = node;
  83. } else if (eqv.find(node) == eqv.end()) {
  84. bool ignore_make_tuple = false;
  85. if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
  86. ignore_make_tuple = true;
  87. auto cnode = node->cast<CNodePtr>();
  88. MS_EXCEPTION_IF_NULL(cnode);
  89. const auto &node_inputs = cnode->inputs();
  90. for (size_t i = 1; i < node_inputs.size(); ++i) {
  91. if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) {
  92. ignore_make_tuple = false;
  93. break;
  94. }
  95. }
  96. }
  97. if (!ignore_make_tuple) {
  98. inputs.push_back(node);
  99. }
  100. eqv[node] = fg->add_parameter();
  101. eqv[node]->set_abstract(node->abstract());
  102. eqv[node]->set_kernel_info(node->kernel_info_ptr());
  103. }
  104. return eqv[node];
  105. }
  106. } // namespace
  107. std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
  108. if (lst.empty()) {
  109. MS_LOG(EXCEPTION) << "Input anf node list is empty";
  110. }
  111. TraceManager::DebugTrace(
  112. std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
  113. auto fg = std::make_shared<FuncGraph>();
  114. TraceManager::EndTrace();
  115. AnfNodePtrList inputs;
  116. AnfNodePtrToAnfNodePtrMap eqv;
  117. // Merge CNodes into a AnfGraph that represents a linear instruction segment
  118. for (auto n : lst) {
  119. if (!n->isa<CNode>()) {
  120. MS_LOG(EXCEPTION) << "Inst is not CNode";
  121. }
  122. auto &inps = n->cast<CNodePtr>()->inputs();
  123. if (inps.empty()) {
  124. MS_LOG(EXCEPTION) << "Input is empty";
  125. }
  126. if (!IsValueNode<Primitive>(inps[0]) &&
  127. !(IsValueNode<FuncGraph>(inps[0]) &&
  128. inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
  129. MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode";
  130. }
  131. auto fn = inps[0];
  132. std::vector<AnfNodePtr> args{fn};
  133. if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
  134. args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv));
  135. args.emplace_back(NewValueNode(MakeValue(0)));
  136. } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) {
  137. for (size_t i = 1; i < inps.size(); ++i) {
  138. if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
  139. args.emplace_back(NewValueNode(MakeValue(i)));
  140. } else {
  141. args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv));
  142. }
  143. }
  144. } else {
  145. (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
  146. [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
  147. }
  148. TraceManager::DebugTrace(std::make_shared<TraceGetEnv>(n->debug_info()));
  149. eqv[n] = fg->NewCNode(args);
  150. TraceManager::EndTrace();
  151. eqv[n]->set_abstract(n->abstract());
  152. eqv[n]->set_kernel_info(n->kernel_info_ptr());
  153. }
  154. std::vector<AnfNodePtr> eqv_keys;
  155. (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
  156. [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
  157. auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
  158. AnfNodePtr fg_output;
  159. if (outputs.size() > 1) {
  160. std::vector<AnfNodePtr> output_args;
  161. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  162. (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
  163. [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
  164. // Set output for AnfGraph
  165. fg_output = fg->NewCNode(output_args);
  166. } else {
  167. fg_output = eqv[outputs[0]];
  168. }
  169. fg->set_output(fg_output);
  170. return std::make_tuple(fg, inputs, outputs);
  171. }
  172. // Converts the list of nodes to a runnable form.
  173. // All the nodes in the list must represent linear flow (no calls, branches, ...)
  174. // Returns:
  175. // (fn, inputs, outputs):
  176. // - fn: A callable function
  177. // - inputs: the list of inputs nodes whose values should be
  178. // provided to the function
  179. // - outputs: the list of output nodes corresponding to the
  180. // outputs of the function
  181. // Notes:
  182. // This implementation will convert the nodes into a subgraph
  183. // that will run using the MsVM.
  184. template <typename T>
  185. LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) {
  186. auto cached = g_ConvertCache.find(lst);
  187. if (cached != g_ConvertCache.end()) {
  188. return cached->second;
  189. }
  190. LinConvertResult result;
  191. FuncGraphPtr fg = nullptr;
  192. AnfNodePtrList inputs;
  193. AnfNodePtrList outputs;
  194. std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst);
  195. // Clone in case g contains subgraphs that have a different manager
  196. fg = BasicClone(fg);
  197. std::shared_ptr<VMImpl> vm = std::make_shared<T>();
  198. result.run =
  199. std::make_shared<RunFunc>([fg, vm](const VectorRef &args) -> VectorRef { return vm->RunGraph(fg, args); });
  200. result.inputs = inputs;
  201. result.outputs = outputs;
  202. result.graph_id = UINT32_MAX;
  203. (void)g_ConvertCache.emplace(lst, result);
  204. return result;
  205. }
  206. LinkFuncType MsVmConvert = Convert<VM>;
  207. std::unordered_map<std::string, LinkFuncType> backends = {{kMsVm, MsVmConvert}};
  208. std::set<std::string> backend_list = {
  209. kMsConvert,
  210. kMsVm,
  211. };
  212. } // namespace compile
  213. } // namespace mindspore