You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

segment_runner.cc 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
#include "vm/segment_runner.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "utils/log_adapter.h"
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
  31. namespace mindspore {
// Backend name strings.
// kMsConvert and kMsVm are the entries of `backend_list` in this file, and
// kMsVm is also the key under which the MsVM segment converter is registered
// in `backends`. kGeVm is not referenced in this file.
// NOTE(review): presumably declared `extern` in vm/segment_runner.h — confirm
// before changing their declaration form.
const char kMsConvert[] = "ms";
const char kMsVm[] = "vm";
const char kGeVm[] = "ge";
  35. namespace compile {
  36. // cached conversion
  37. ConvertCache g_ConvertCache;
  38. void ClearConvertCache() { g_ConvertCache.clear(); }
  39. // Return the list of nodes whose values are required beyond this segment.
  40. // Arguments:
  41. // lst: list of nodes (the segment)
  42. // users: dict mapping each node to its users (globally)
  43. // seen: set of nodes that are part of the segment
  44. AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector<AnfNodePtr> &seen) {
  45. AnfNodePtrList output;
  46. if (users.size() == 0) {
  47. return output;
  48. }
  49. (void)std::transform(
  50. std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr {
  51. auto usersn = users.find(n);
  52. bool is_referred_out_of_segment = std::any_of(
  53. std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair<AnfNodePtr, int> &u) -> bool {
  54. return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen);
  55. });
  56. if (n->isa<CNode>() && is_referred_out_of_segment) {
  57. return n;
  58. }
  59. return nullptr;
  60. });
  61. // remove nullptr
  62. for (auto it = output.begin(); it != output.end();) {
  63. if (*it == nullptr) {
  64. it = output.erase(it);
  65. } else {
  66. ++it;
  67. }
  68. }
  69. return output;
  70. }
  71. std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
  72. auto fg = std::make_shared<FuncGraph>();
  73. AnfNodePtrList inputs;
  74. AnfNodePtrToAnfNodePtrMap eqv;
  75. if (lst.empty()) {
  76. MS_LOG(EXCEPTION) << "Input anf node list is empty";
  77. }
  78. auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr {
  79. if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) {
  80. eqv[a] = a;
  81. } else if (eqv.find(a) == eqv.end()) {
  82. inputs.push_back(a);
  83. eqv[a] = fg->add_parameter();
  84. }
  85. return eqv[a];
  86. };
  87. // Merge CNodes into a AnfGraph that represents a linear instruction segment
  88. for (auto n : lst) {
  89. if (!n->isa<CNode>()) {
  90. MS_LOG(EXCEPTION) << "Inst is not CNode";
  91. }
  92. auto &inps = n->cast<CNodePtr>()->inputs();
  93. if (inps.empty()) {
  94. MS_LOG(EXCEPTION) << "Input is empty";
  95. }
  96. if (!IsValueNode<Primitive>(inps[0])) {
  97. MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode";
  98. }
  99. auto fn = inps[0];
  100. std::vector<AnfNodePtr> args{fn};
  101. (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
  102. eqv[n] = fg->NewCNode(args);
  103. }
  104. std::vector<AnfNodePtr> eqv_keys;
  105. (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
  106. [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
  107. auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
  108. std::vector<AnfNodePtr> output_args;
  109. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  110. (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
  111. [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
  112. // Set output for AnfGraph
  113. auto fg_output = fg->NewCNode(output_args);
  114. fg->set_output(fg_output);
  115. return std::make_tuple(fg, inputs, outputs);
  116. }
  117. // Converts the list of nodes to a runnable form.
  118. // All the nodes in the list must represent linear flow (no calls, branches, ...)
  119. // Returns:
  120. // (fn, inputs, outputs):
  121. // - fn: A callable function
  122. // - inputs: the list of inputs nodes whose values should be
  123. // provided to the function
  124. // - outputs: the list of output nodes corresponding to the
  125. // outputs of the function
  126. // Notes:
  127. // This implementation will convert the nodes into a subgraph
  128. // that will run using the MsVM.
  129. template <typename T>
  130. LinConvertResult Convert(const AnfNodePtrList &lst) {
  131. auto cached = g_ConvertCache.find(lst);
  132. if (cached != g_ConvertCache.end()) {
  133. return cached->second;
  134. }
  135. LinConvertResult result;
  136. FuncGraphPtr fg = nullptr;
  137. AnfNodePtrList inputs;
  138. AnfNodePtrList outputs;
  139. std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst);
  140. // Clone in case g contains subgraphs that have a different manager
  141. fg = BasicClone(fg);
  142. std::shared_ptr<VMImpl> vm = std::make_shared<T>();
  143. result.run =
  144. std::make_shared<RunFunc>([fg, vm](const VectorRef &args) -> VectorRef { return vm->RunGraph(fg, args); });
  145. result.inputs = inputs;
  146. result.outputs = outputs;
  147. result.graph_id = UINT32_MAX;
  148. (void)g_ConvertCache.emplace(lst, result);
  149. return result;
  150. }
  151. LinkFuncType MsVmConvert = Convert<VM>;
  152. std::unordered_map<std::string, LinkFuncType> backends = {{kMsVm, MsVmConvert}};
  153. std::set<std::string> backend_list = {
  154. kMsConvert,
  155. kMsVm,
  156. };
  157. } // namespace compile
  158. } // namespace mindspore