You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ascend_control_parser.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <utility>
  17. #include <memory>
  18. #include "session/ascend_control_parser.h"
  19. #include "session/anf_runtime_algorithm.h"
  20. namespace mindspore {
  21. namespace session {
  22. static VectorRef GetCallArgs(std::vector<AnfNodePtr>::iterator iter_begin, std::vector<AnfNodePtr>::iterator iter_end) {
  23. VectorRef call_args;
  24. for (auto iter = iter_begin; iter != iter_end; ++iter) {
  25. if (utils::isa<ValueNode>(*iter)) {
  26. call_args.push_back(GetValueNode(*iter));
  27. } else {
  28. call_args.push_back(*iter);
  29. }
  30. }
  31. return call_args;
  32. }
  33. void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
  34. std::set<KernelGraphPtr> memo;
  35. ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo));
  36. }
  37. NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
  38. const CNodePtr &last_label, const VectorRef &args,
  39. NotNull<std::set<KernelGraphPtr> *> memo) {
  40. MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString();
  41. // 0. recursive condition
  42. if (memo->find(kg) != memo->end()) {
  43. MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString();
  44. return NOT_NULL(kg->get_start_label());
  45. }
  46. // 2. args replace placeholder
  47. LinkParentGraph(kg, last_node, last_label, args);
  48. // 3. topological sort
  49. std::vector<CNodePtr> nodes = GetCNodes(TopoSort(kg->get_return()));
  50. if (nodes.empty()) {
  51. MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!";
  52. }
  53. // 4. insert first_label
  54. auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  55. for (auto node : nodes) {
  56. if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
  57. InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node));
  58. break;
  59. }
  60. }
  61. kg->set_start_label(start_label);
  62. // 5. traverse
  63. for (size_t i = 0; i < nodes.size(); ++i) {
  64. auto &cnode = nodes[i];
  65. if (cnode->size() < kCNodePrim + 1) {
  66. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  67. }
  68. AnfNodePtr fn = cnode->input(kCNodePrim);
  69. if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
  70. MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
  71. continue;
  72. }
  73. AnfNodePtr arg = cnode->input(kCNodeCallArg);
  74. if (IsValueNode<KernelGraph>(arg)) {
  75. RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo);
  76. } else if (!arg->isa<CNode>()) {
  77. MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString();
  78. } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) {
  79. auto arg_cnode = arg->cast<CNodePtr>();
  80. cnode->set_inputs(cnode->inputs());
  81. RecurseSwitch(kg, NOT_NULL(cnode), memo);
  82. } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) {
  83. auto arg_cnode = arg->cast<CNodePtr>();
  84. cnode->set_inputs(cnode->inputs());
  85. RecurseSwitchLayer(kg, NOT_NULL(cnode), memo);
  86. }
  87. }
  88. MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
  89. return NOT_NULL(start_label);
  90. }
  91. std::vector<CNodePtr> AscendControlParser::GetCNodes(const std::vector<AnfNodePtr> &in) {
  92. std::vector<CNodePtr> out;
  93. for (auto &node : in) {
  94. if (node->isa<CNode>()) {
  95. out.push_back(node->cast<CNodePtr>());
  96. }
  97. }
  98. return out;
  99. }
  100. void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) {
  101. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
  102. auto return_node = kg->get_return();
  103. MS_EXCEPTION_IF_NULL(return_node);
  104. inputs.push_back(return_node->input(1));
  105. inputs.push_back(attch_node.get());
  106. auto depend_node = kg->NewCNode(inputs);
  107. return_node->set_input(1, depend_node);
  108. }
  109. void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
  110. NotNull<AnfNodePtr> second_node) {
  111. MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
  112. << ", the second node is " << second_node->DebugString();
  113. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
  114. first_node, second_node};
  115. auto control_depend = kg->NewCNode(inputs);
  116. InsertDependToGraph(kg, NOT_NULL(control_depend));
  117. }
  118. void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
  119. const CNodePtr &last_label, const VectorRef &args) {
  120. if (from_graph_call_node != nullptr) {
  121. SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args);
  122. }
  123. auto origin_return = kg->get_return();
  124. std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs();
  125. // if entry graph, replace return with make_tuple
  126. if (from_graph_call_node == nullptr || last_label == nullptr) {
  127. MS_LOG(INFO) << kg->ToString() << " is entry graph.";
  128. std::vector<AnfNodePtr> make_tuple_inputs = {std::make_shared<ValueNode>(prim::kPrimMakeTuple)};
  129. make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end());
  130. auto make_tuple = kg->NewCNode(make_tuple_inputs);
  131. origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple});
  132. } else {
  133. // else replace return with label_goto
  134. auto label_goto =
  135. kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
  136. InsertDependToGraph(kg, NOT_NULL(label_goto));
  137. }
  138. }
  139. void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
  140. NotNull<std::set<KernelGraphPtr> *> memo) {
  141. MS_LOG(INFO) << "process call func " << cur_node->DebugString();
  142. // 1 get kernel graph
  143. auto origin_inputs = cur_node->inputs();
  144. std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
  145. auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end());
  146. if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
  147. MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
  148. return;
  149. }
  150. // 2 return label
  151. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  152. // 3 add depend relationship
  153. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  154. if (next_node != nullptr && next_node != kg->get_return()) {
  155. InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
  156. }
  157. auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]);
  158. // 4 modify call op to goto op
  159. cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]);
  160. // 5 recurse sub graph
  161. CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo);
  162. new_inputs.push_back(sub_label);
  163. new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
  164. cur_node->set_inputs(new_inputs);
  165. cur_node->set_abstract(nullptr);
  166. MS_LOG(INFO) << "success process call func " << cur_node->DebugString();
  167. }
  168. void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
  169. NotNull<std::set<KernelGraphPtr> *> memo) {
  170. MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
  171. if (cur_node->size() < kCNodeSwitchLength) {
  172. MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength;
  173. }
  174. // 1 return label
  175. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(prim::kPrimLabelSet)});
  176. // 2 recurse sub graph
  177. auto origin_switch_inputs = cur_node->inputs();
  178. std::vector<AnfNodePtr> new_switch_inputs = {
  179. std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
  180. origin_switch_inputs[kCNodeSwitchCond]};
  181. for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
  182. // 2.1 branch kernel graph and args
  183. CNodePtr partial;
  184. KernelGraphPtr branch_fg;
  185. VectorRef call_args;
  186. std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
  187. // 2.2 add depend relationship
  188. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  189. // 2.3 recurse sub graph
  190. CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo);
  191. new_switch_inputs.push_back(branch_label);
  192. }
  193. std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
  194. new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
  195. cur_node->set_inputs(new_switch_inputs);
  196. cur_node->set_abstract(nullptr);
  197. MS_LOG(INFO) << "success process switch func " << cur_node->DebugString();
  198. }
  199. void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
  200. NotNull<std::set<KernelGraphPtr> *> memo) {
  201. MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
  202. if (cur_node->size() < kCNodeSwitchLayerLength) {
  203. MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
  204. }
  205. auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
  206. MS_EXCEPTION_IF_NULL(branch_tuple);
  207. if (!branch_tuple->isa<CNode>()) {
  208. MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
  209. }
  210. auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
  211. // 1 return label
  212. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName))});
  213. // 2 recurse sub graph
  214. auto origin_switch_inputs = cur_node->inputs();
  215. std::vector<AnfNodePtr> new_switch_inputs = {std::make_shared<ValueNode>(prim::kPrimLabelSwitch),
  216. origin_switch_inputs[kCNodeSwitchCond]};
  217. for (size_t i = 0; i < branch_partial.size(); ++i) {
  218. // 2.1 branch kernel graph and args
  219. CNodePtr partial;
  220. KernelGraphPtr branch_fg;
  221. VectorRef call_args;
  222. std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
  223. // 2.2 add depend relationship
  224. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  225. // 2.3 recurse sub graph
  226. CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo);
  227. new_switch_inputs.push_back(branch_label);
  228. }
  229. new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
  230. cur_node->set_inputs(new_switch_inputs);
  231. cur_node->set_abstract(nullptr);
  232. MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString();
  233. }
  234. std::tuple<CNodePtr, KernelGraphPtr, VectorRef> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
  235. if (!node.get()->isa<CNode>()) {
  236. MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
  237. }
  238. // 2.1 branch kernel graph and args
  239. auto partial_cnode = utils::cast<CNodePtr>(node.get());
  240. if (partial_cnode->size() < kCNodePartialLength) {
  241. MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
  242. }
  243. auto partial_inputs = partial_cnode->inputs();
  244. auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
  245. auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end());
  246. return {partial_cnode, branch_kg, call_args};
  247. }
  248. void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
  249. NotNull<AnfNodePtr> to) {
  250. if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
  251. AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
  252. return;
  253. }
  254. if (from.get() == to.get()) {
  255. return;
  256. }
  257. MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
  258. << to->DebugString();
  259. // config inputs of assign node
  260. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
  261. // generate a new cnode
  262. auto assign_node = kg->NewCNode(inputs);
  263. MS_EXCEPTION_IF_NULL(assign_node);
  264. assign_node->set_abstract(to->abstract());
  265. // append the assign at the end of from graph
  266. InsertDependToGraph(kg, NOT_NULL(assign_node));
  267. }
  268. size_t AscendControlParser::SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node,
  269. size_t input_index) {
  270. auto output_num = AnfAlgo::GetOutputTensorNum(node);
  271. if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
  272. return input_index + output_num;
  273. }
  274. auto &graph_inputs = kg->inputs();
  275. if (input_index >= graph_inputs.size()) {
  276. MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
  277. }
  278. auto backend_parameter = graph_inputs[input_index];
  279. if (node.get()->isa<Parameter>()) {
  280. MS_EXCEPTION_IF_NULL(backend_parameter);
  281. MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString()
  282. << "] will be replaced.";
  283. kg->ReplaceNode(backend_parameter, node);
  284. return input_index;
  285. }
  286. InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter));
  287. return input_index + 1;
  288. }
  289. void AscendControlParser::SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node,
  290. const VectorRef &args) {}
  291. } // namespace session
  292. } // namespace mindspore