You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

segment_runner.cc 7.4 kB

4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "vm/segment_runner.h"
  19. #include <algorithm>
  20. #include <functional>
  21. #include <memory>
  22. #include <set>
  23. #include <tuple>
  24. #include <utility>
  25. #include <string>
  26. #include "utils/hash_map.h"
  27. #include "utils/hash_set.h"
  28. #include "utils/log_adapter.h"
  29. #include "utils/utils.h"
  30. #include "ir/manager.h"
  31. #include "ir/func_graph_cloner.h"
  32. #include "frontend/operator/ops.h"
  33. namespace mindspore {
  34. namespace compile {
  35. namespace {
  36. // Return the list of nodes whose values are required beyond this segment.
  37. // Arguments:
  38. // nodes: list of nodes in the segment
  39. // users: dict mapping each node to its users (globally)
  40. // seen: set of nodes that are part of the segment
  41. AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
  42. const mindspore::HashSet<AnfNodePtr> &seen) {
  43. AnfNodePtrList output;
  44. if (users.size() == 0) {
  45. return output;
  46. }
  47. for (auto &node : nodes) {
  48. MS_EXCEPTION_IF_NULL(node);
  49. if (!node->isa<CNode>()) {
  50. continue;
  51. }
  52. auto iter = users.find(node);
  53. if (iter == users.end()) {
  54. continue;
  55. }
  56. auto &node_users = iter->second;
  57. const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users),
  58. [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
  59. const bool is_outer_user = (seen.find(u.first) == seen.end());
  60. return is_outer_user;
  61. });
  62. if (has_outer_user) {
  63. output.emplace_back(node);
  64. }
  65. }
  66. return output;
  67. }
  68. AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr,
  69. AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
  70. MS_EXCEPTION_IF_NULL(fg);
  71. MS_EXCEPTION_IF_NULL(inputs_ptr);
  72. MS_EXCEPTION_IF_NULL(eqv_ptr);
  73. MS_EXCEPTION_IF_NULL(node);
  74. auto &inputs = *inputs_ptr;
  75. auto &eqv = *eqv_ptr;
  76. if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
  77. eqv[node] = node;
  78. } else if (eqv.find(node) == eqv.end()) {
  79. inputs.push_back(node);
  80. eqv[node] = fg->add_parameter();
  81. eqv[node]->set_abstract(node->abstract());
  82. eqv[node]->set_kernel_info(node->kernel_info_ptr());
  83. }
  84. return eqv[node];
  85. }
  86. } // namespace
  87. std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
  88. if (lst.empty()) {
  89. MS_LOG(EXCEPTION) << "Input anf node list is empty";
  90. }
  91. FuncGraphPtr fg = nullptr;
  92. {
  93. // limit the lifetime of guard.
  94. MS_EXCEPTION_IF_NULL(lst[0]->cast<CNodePtr>());
  95. MS_EXCEPTION_IF_NULL(lst[0]->cast<CNodePtr>()->func_graph());
  96. TraceGuard guard(std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
  97. fg = std::make_shared<FuncGraph>();
  98. }
  99. AnfNodePtrList inputs;
  100. AnfNodePtrToAnfNodePtrMap eqv;
  101. // Merge CNodes into a AnfGraph that represents a linear instruction segment
  102. for (auto n : lst) {
  103. MS_EXCEPTION_IF_NULL(n);
  104. if (!n->isa<CNode>()) {
  105. MS_LOG(EXCEPTION) << "Inst is not CNode";
  106. }
  107. auto &inps = n->cast<CNodePtr>()->inputs();
  108. if (inps.empty()) {
  109. MS_LOG(EXCEPTION) << "Input is empty";
  110. }
  111. if (!IsValueNode<Primitive>(inps[0]) &&
  112. !(IsValueNode<FuncGraph>(inps[0]) &&
  113. inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
  114. MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive ValueNode";
  115. }
  116. auto fn = inps[0];
  117. std::vector<AnfNodePtr> args{fn};
  118. if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= kDependInputSize &&
  119. eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
  120. args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv));
  121. const size_t value_start_index = 2;
  122. for (size_t i = value_start_index; i < inps.size(); ++i) {
  123. args.emplace_back(NewValueNode(MakeValue(0)));
  124. }
  125. } else {
  126. (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
  127. [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
  128. }
  129. TraceGuard tg(std::make_shared<TraceSegmentTransform>(n->debug_info()));
  130. MS_EXCEPTION_IF_NULL(fg);
  131. eqv[n] = fg->NewCNode(args);
  132. eqv[n]->set_abstract(n->abstract());
  133. eqv[n]->set_kernel_info(n->kernel_info_ptr());
  134. }
  135. mindspore::HashSet<AnfNodePtr> eqv_keys;
  136. for (auto &e : eqv) {
  137. eqv_keys.emplace(e.first);
  138. }
  139. auto mgr = lst[0]->func_graph()->manager();
  140. MS_EXCEPTION_IF_NULL(mgr);
  141. auto outputs = GetOutput(lst, mgr->node_users(), eqv_keys);
  142. AnfNodePtr fg_output;
  143. if (outputs.size() > 1) {
  144. std::vector<AnfNodePtr> output_args;
  145. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  146. (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
  147. [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
  148. // Set output for AnfGraph
  149. fg_output = fg->NewCNode(output_args);
  150. } else {
  151. if (outputs.empty()) {
  152. MS_LOG(EXCEPTION) << "Output is empty.";
  153. }
  154. fg_output = eqv[outputs[0]];
  155. }
  156. fg->set_output(fg_output);
  157. return std::make_tuple(fg, inputs, outputs);
  158. }
  159. // Converts the list of nodes to a runnable form.
  160. // All the nodes in the list must represent linear flow (no calls, branches, ...)
  161. // Returns:
  162. // (fn, inputs, outputs):
  163. // - fn: A callable function
  164. // - inputs: the list of inputs nodes whose values should be
  165. // provided to the function
  166. // - outputs: the list of output nodes corresponding to the
  167. // outputs of the function
  168. // Notes:
  169. // This implementation will convert the nodes into a subgraph
  170. // that will run using the MsVM.
  171. template <typename T>
  172. LinConvertResult Convert(const GraphSegmentPtr &segment, const std::string &) {
  173. MS_EXCEPTION_IF_NULL(segment);
  174. LinConvertResult result;
  175. FuncGraphPtr fg = nullptr;
  176. AnfNodePtrList inputs;
  177. AnfNodePtrList outputs;
  178. std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
  179. // Clone in case g contains subgraphs that have a different manager
  180. fg = BasicClone(fg);
  181. std::shared_ptr<VMImpl> vm = std::make_shared<T>();
  182. result.run =
  183. std::make_shared<RunFunc>([fg, vm](const VectorRef &args) -> VectorRef { return vm->RunGraph(fg, args); });
  184. result.inputs = inputs;
  185. result.outputs = outputs;
  186. result.graph_id = UINT32_MAX;
  187. return result;
  188. }
  189. LinkFuncType MsVmConvert = Convert<VM>;
  190. std::set<std::string> backend_list = {
  191. kMsConvert,
  192. kMsVm,
  193. };
  194. } // namespace compile
  195. } // namespace mindspore