You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

segment_runner.cc 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/segment_runner.h"
  19. #include <algorithm>
  20. #include <functional>
  21. #include <memory>
  22. #include <set>
  23. #include <tuple>
  24. #include <unordered_map>
  25. #include <utility>
  26. #include <string>
  27. #include "utils/log_adapter.h"
  28. #include "utils/utils.h"
  29. #include "ir/manager.h"
  30. #include "ir/func_graph_cloner.h"
  31. #include "operator/ops.h"
  32. namespace mindspore {
  // Backend name strings.
  // In the visible part of this file kMsVm is the key registered in `backends`, kMsConvert and kMsVm
  // are both listed in `backend_list`, and kGeVm is not referenced.
  const char kMsConvert[] = "ms";
  const char kMsVm[] = "vm";
  const char kGeVm[] = "ge";
  36. namespace compile {
  // Global cache of conversion results. Convert() looks a segment's node list up here before doing any
  // work and stores every freshly built result under that same node list.
  ConvertCache g_ConvertCache;
  // Discards every cached conversion result.
  void ClearConvertCache() { g_ConvertCache.clear(); }
  40. // Return the list of nodes whose values are required beyond this segment.
  41. // Arguments:
  42. // lst: list of nodes (the segment)
  43. // users: dict mapping each node to its users (globally)
  44. // seen: set of nodes that are part of the segment
  45. AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector<AnfNodePtr> &seen) {
  46. AnfNodePtrList output;
  47. if (users.size() == 0) {
  48. return output;
  49. }
  50. (void)std::transform(
  51. std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr {
  52. auto usersn = users.find(n);
  53. bool is_referred_out_of_segment = std::any_of(
  54. std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair<AnfNodePtr, int> &u) -> bool {
  55. return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen);
  56. });
  57. if (n->isa<CNode>() && is_referred_out_of_segment) {
  58. return n;
  59. }
  60. return nullptr;
  61. });
  62. // remove nullptr
  63. for (auto it = output.begin(); it != output.end();) {
  64. if (*it == nullptr) {
  65. it = output.erase(it);
  66. } else {
  67. ++it;
  68. }
  69. }
  70. return output;
  71. }
  72. std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
  73. auto fg = std::make_shared<FuncGraph>();
  74. AnfNodePtrList inputs;
  75. AnfNodePtrToAnfNodePtrMap eqv;
  76. if (lst.empty()) {
  77. MS_LOG(EXCEPTION) << "Input anf node list is empty";
  78. }
  79. auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr {
  80. if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) {
  81. eqv[a] = a;
  82. } else if (eqv.find(a) == eqv.end()) {
  83. inputs.push_back(a);
  84. eqv[a] = fg->add_parameter();
  85. eqv[a]->set_abstract(a->abstract());
  86. eqv[a]->set_kernel_info(a->kernel_info_ptr());
  87. }
  88. return eqv[a];
  89. };
  90. // Merge CNodes into a AnfGraph that represents a linear instruction segment
  91. for (auto n : lst) {
  92. if (!n->isa<CNode>()) {
  93. MS_LOG(EXCEPTION) << "Inst is not CNode";
  94. }
  95. auto &inps = n->cast<CNodePtr>()->inputs();
  96. if (inps.empty()) {
  97. MS_LOG(EXCEPTION) << "Input is empty";
  98. }
  99. if (!IsValueNode<Primitive>(inps[0]) &&
  100. !(IsValueNode<FuncGraph>(inps[0]) &&
  101. inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
  102. MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode";
  103. }
  104. auto fn = inps[0];
  105. std::vector<AnfNodePtr> args{fn};
  106. if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa<ValueNode>() &&
  107. eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
  108. args.emplace_back(inps[kRealInputIndexInDepend]);
  109. args.emplace_back(inps[kRealInputIndexInDepend]);
  110. } else {
  111. (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
  112. }
  113. eqv[n] = fg->NewCNode(args);
  114. eqv[n]->set_abstract(n->abstract());
  115. eqv[n]->set_kernel_info(n->kernel_info_ptr());
  116. }
  117. std::vector<AnfNodePtr> eqv_keys;
  118. (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
  119. [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
  120. auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
  121. AnfNodePtr fg_output;
  122. if (outputs.size() > 1) {
  123. std::vector<AnfNodePtr> output_args;
  124. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  125. (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
  126. [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
  127. // Set output for AnfGraph
  128. fg_output = fg->NewCNode(output_args);
  129. } else {
  130. fg_output = eqv[outputs[0]];
  131. }
  132. fg->set_output(fg_output);
  133. return std::make_tuple(fg, inputs, outputs);
  134. }
  135. // Converts the list of nodes to a runnable form.
  136. // All the nodes in the list must represent linear flow (no calls, branches, ...)
  137. // Returns:
  138. // (fn, inputs, outputs):
  139. // - fn: A callable function
  140. // - inputs: the list of inputs nodes whose values should be
  141. // provided to the function
  142. // - outputs: the list of output nodes corresponding to the
  143. // outputs of the function
  144. // Notes:
  145. // This implementation will convert the nodes into a subgraph
  146. // that will run using the MsVM.
  147. template <typename T>
  148. LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) {
  149. auto cached = g_ConvertCache.find(lst);
  150. if (cached != g_ConvertCache.end()) {
  151. return cached->second;
  152. }
  153. LinConvertResult result;
  154. FuncGraphPtr fg = nullptr;
  155. AnfNodePtrList inputs;
  156. AnfNodePtrList outputs;
  157. std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst);
  158. // Clone in case g contains subgraphs that have a different manager
  159. fg = BasicClone(fg);
  160. std::shared_ptr<VMImpl> vm = std::make_shared<T>();
  161. result.run =
  162. std::make_shared<RunFunc>([fg, vm](const VectorRef &args) -> VectorRef { return vm->RunGraph(fg, args); });
  163. result.inputs = inputs;
  164. result.outputs = outputs;
  165. result.graph_id = UINT32_MAX;
  166. (void)g_ConvertCache.emplace(lst, result);
  167. return result;
  168. }
  // Converter for the "vm" backend: Convert instantiated with the VM type.
  LinkFuncType MsVmConvert = Convert<VM>;
  // Backend name -> converter. kMsVm is the only key registered here.
  std::unordered_map<std::string, LinkFuncType> backends = {{kMsVm, MsVmConvert}};
  // Backend names listed by this file. kMsConvert appears here although `backends` has no entry for it.
  // NOTE(review): presumably that name is handled elsewhere - confirm.
  std::set<std::string> backend_list = {
    kMsConvert,
    kMsVm,
  };
  175. } // namespace compile
  176. } // namespace mindspore