You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_graph.cc 36 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/kernel_graph.h"
  17. #include <algorithm>
  18. #include <queue>
  19. #include <unordered_set>
  20. #include <set>
  21. #include "operator/ops.h"
  22. #include "ir/param_value_py.h"
  23. #include "session/anf_runtime_algorithm.h"
  24. #include "device/kernel_info.h"
  25. #include "kernel/kernel_build_info.h"
  26. #include "device/kernel_runtime_manager.h"
  27. #include "kernel/common_utils.h"
  28. namespace mindspore {
  29. namespace session {
  30. namespace {
// Attribute keys used to tag real backend kernels with feature-map
// information (written in KernelGraph::NewCNode).
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
  33. void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  34. std::unordered_set<AnfNodePtr> *visited_nodes) {
  35. MS_EXCEPTION_IF_NULL(node);
  36. MS_EXCEPTION_IF_NULL(que);
  37. MS_EXCEPTION_IF_NULL(visited_nodes);
  38. if (visited_nodes->find(node) == visited_nodes->end()) {
  39. que->push(node);
  40. (void)visited_nodes->insert(node);
  41. MS_LOG(DEBUG) << "Push que:" << node->DebugString();
  42. }
  43. }
  44. std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
  45. auto item_with_index =
  46. AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
  47. AnfNodePtr node = item_with_index.first;
  48. MS_EXCEPTION_IF_NULL(node);
  49. if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
  50. auto outputs = AnfAlgo::GetAllOutput(node);
  51. std::set<AnfNodePtr> memo;
  52. std::vector<AnfNodePtr> new_output;
  53. for (auto &output : outputs) {
  54. if (memo.find(output) != memo.end()) {
  55. continue;
  56. }
  57. memo.insert(output);
  58. new_output.push_back(output);
  59. }
  60. if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) {
  61. node = new_output[0];
  62. }
  63. }
  64. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
  65. return {node};
  66. }
  67. std::vector<AnfNodePtr> real_inputs;
  68. auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast<CNodePtr>());
  69. for (const auto &child_graph : child_graphs) {
  70. if (child_graph->get_output_null()) {
  71. continue;
  72. }
  73. auto real_input = child_graph->output();
  74. auto child_real_inputs = GetCallRealOutputs(real_input);
  75. std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs));
  76. }
  77. return real_inputs;
  78. }
  79. AnfNodePtr MakeValueNode(const AnfNodePtr &node) {
  80. auto value_node = node->cast<ValueNodePtr>();
  81. if (value_node == nullptr) {
  82. return nullptr;
  83. }
  84. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  85. new_value_node->set_abstract(value_node->abstract());
  86. // create kernel_info fo new value node
  87. auto kernel_info = std::make_shared<device::KernelInfo>();
  88. new_value_node->set_kernel_info(kernel_info);
  89. // create kernel_build_info for new value node
  90. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  91. // set the format of value_node to DEFAULT_FORMAT
  92. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  93. // set value node initial device data type = infer data type
  94. std::vector<TypeId> types;
  95. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
  96. types.push_back(kTypeUnknown);
  97. }
  98. kernel_build_info_builder->SetOutputsDeviceType(types);
  99. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  100. return new_value_node;
  101. }
  102. } // namespace
  103. std::vector<AnfNodePtr> KernelGraph::outputs() const {
  104. auto graph_output = output();
  105. if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
  106. auto make_tuple = output()->cast<CNodePtr>();
  107. MS_EXCEPTION_IF_NULL(make_tuple);
  108. auto &inputs = make_tuple->inputs();
  109. return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end());
  110. }
  111. return std::vector<AnfNodePtr>(1, graph_output);
  112. }
  113. void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
  114. std::unordered_set<AnfNodePtr> *visited_nodes) {
  115. MS_EXCEPTION_IF_NULL(visit_queue);
  116. MS_EXCEPTION_IF_NULL(visited_nodes);
  117. auto it = node_output_edges_.find(node);
  118. if (it == node_output_edges_.end()) {
  119. // value node and parameter has no input,no need to print log
  120. if (node->isa<CNode>()) {
  121. MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
  122. }
  123. return;
  124. }
  125. // visit all reduce node first, then other nodes
  126. std::vector<AnfNodePtr> active_nodes;
  127. for (const auto &output_edge : it->second) {
  128. auto next_node = output_edge.first;
  129. MS_EXCEPTION_IF_NULL(next_node);
  130. if (node_input_num_.find(next_node) == node_input_num_.end()) {
  131. MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
  132. }
  133. MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
  134. << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
  135. if (node_input_num_[next_node] < output_edge.second) {
  136. MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
  137. << ",depend edge:" << output_edge.second;
  138. }
  139. node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
  140. // allreduce first
  141. if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
  142. (void)visited_nodes->insert(next_node);
  143. if (AnfAlgo::IsCommunicationOp(next_node)) {
  144. MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
  145. visit_queue->push(next_node);
  146. } else {
  147. active_nodes.emplace_back(next_node);
  148. }
  149. }
  150. }
  151. for (auto &node : active_nodes) {
  152. MS_EXCEPTION_IF_NULL(node);
  153. MS_LOG(DEBUG) << "Visit node:" << node->DebugString();
  154. visit_queue->push(node);
  155. }
  156. }
// Topologically sort the graph into execution_order_ (Kahn's algorithm over
// node_input_num_), giving communication ops (e.g. AllReduce) priority: once a
// communication op is scheduled, its ready descendants are drained before the
// common queue. Finally, start_label_ / end_goto_ are moved to the front/back.
void KernelGraph::SetExecOrderByDefault() {
  std::queue<AnfNodePtr> seed_nodes;
  // rebuild edges/counters; seed_nodes receives parameters and value nodes
  UpdateNodeEdgeList(&seed_nodes);
  execution_order_.clear();
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> zero_input_nodes;
  AnfNodePtr last_communication_node = nullptr;
  std::queue<AnfNodePtr> communication_descendants;
  while (!seed_nodes.empty() || last_communication_node != nullptr) {
    // seed nodes first, then visit last all reduce node descendant
    if (seed_nodes.empty()) {
      VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
      last_communication_node = nullptr;
    } else {
      zero_input_nodes.push(seed_nodes.front());
      seed_nodes.pop();
    }
    // all reduce node descendant first, then common queue
    while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
      AnfNodePtr node = nullptr;
      bool is_communication_descendant = false;
      if (communication_descendants.empty()) {
        node = zero_input_nodes.front();
        zero_input_nodes.pop();
      } else {
        node = communication_descendants.front();
        communication_descendants.pop();
        is_communication_descendant = true;
      }
      // add execute node (only real kernels enter the execution order)
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
        execution_order_.push_back(node->cast<CNodePtr>());
      }
      // for all reduce node, visit last all reduce node descendant
      if (AnfAlgo::IsCommunicationOp(node)) {
        if (last_communication_node != nullptr) {
          VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
        }
        last_communication_node = node;
      } else if (is_communication_descendant) {
        VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
      } else {
        VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
      }
    }
  }
  // throws if any node's input counter stayed non-zero (i.e. the graph loops)
  CheckLoop();
  // resort start label / end goto
  std::vector<CNodePtr> re_order;
  if (start_label_ != nullptr) {
    re_order.push_back(start_label_);
  }
  for (auto &node : execution_order_) {
    if (node == start_label_ || node == end_goto_) {
      continue;
    }
    re_order.push_back(node);
  }
  if (end_goto_ != nullptr) {
    re_order.push_back(end_goto_);
  }
  execution_order_ = re_order;
}
  221. void KernelGraph::CheckLoop() {
  222. std::map<AnfNodePtr, size_t> none_zero_nodes;
  223. if (node_input_edges_.size() != node_input_num_.size()) {
  224. MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
  225. << "not equal to node_input_num_ size:" << node_input_num_.size();
  226. }
  227. for (auto &it : node_input_num_) {
  228. MS_EXCEPTION_IF_NULL(it.first);
  229. string str;
  230. auto node_input_it = node_input_edges_.find(it.first);
  231. if (node_input_it == node_input_edges_.end()) {
  232. MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
  233. }
  234. for (const auto &input_edge : node_input_edges_[it.first]) {
  235. MS_EXCEPTION_IF_NULL(input_edge.first);
  236. str = str.append(input_edge.first->DebugString()).append("|");
  237. }
  238. if (it.second != 0) {
  239. MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
  240. none_zero_nodes[it.first] = it.second;
  241. }
  242. }
  243. // if don't consider control depend and loop exit,a exception will be throw
  244. if (!none_zero_nodes.empty()) {
  245. MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
  246. }
  247. }
  248. CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  249. auto cnode = FuncGraph::NewCNode(inputs);
  250. MS_EXCEPTION_IF_NULL(cnode);
  251. cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
  252. CreateKernelInfoFromNewParameter(cnode);
  253. auto kernel_info = std::make_shared<device::KernelInfo>();
  254. std::vector<size_t> feature_map_input_indexs;
  255. // if the node only has the primitive(such as getNext) or the node's input has a feature map input
  256. // then the node's output is a feature map output
  257. for (size_t index = 1; index < inputs.size(); ++index) {
  258. auto node = inputs[index];
  259. if (AnfAlgo::IsFeatureMapOutput(node)) {
  260. feature_map_input_indexs.push_back(index);
  261. }
  262. }
  263. if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
  264. AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
  265. }
  266. if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
  267. kernel_info->SetFeatureMapFlag(true);
  268. }
  269. if (AnfAlgo::IsRealCNodeKernel(cnode)) {
  270. AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
  271. AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
  272. }
  273. cnode->set_kernel_info(kernel_info);
  274. AnfAlgo::SetGraphId(graph_id_, cnode.get());
  275. return cnode;
  276. }
  277. void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
  278. if (!AnfAlgo::IsGraphKernel(cnode)) {
  279. return;
  280. }
  281. auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
  282. MS_EXCEPTION_IF_NULL(func_graph);
  283. std::vector<AnfNodePtr> node_list;
  284. std::vector<AnfNodePtr> input_list;
  285. std::vector<AnfNodePtr> output_list;
  286. kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
  287. for (auto &anf_node : node_list) {
  288. MS_EXCEPTION_IF_NULL(anf_node);
  289. auto kernel_info = std::make_shared<device::KernelInfo>();
  290. anf_node->set_kernel_info(kernel_info);
  291. auto anf_cnode = anf_node->cast<CNodePtr>();
  292. MS_EXCEPTION_IF_NULL(anf_cnode);
  293. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_cnode); ++i) {
  294. auto input_node = anf_cnode->input(i + 1);
  295. MS_EXCEPTION_IF_NULL(input_node);
  296. if (IsValueNode<tensor::Tensor>(input_node)) {
  297. auto new_input_node = MakeValueNode(input_node);
  298. if (new_input_node != nullptr) {
  299. anf_cnode->set_input(i + 1, new_input_node);
  300. }
  301. }
  302. }
  303. }
  304. for (auto &anf_node : input_list) {
  305. MS_EXCEPTION_IF_NULL(anf_node);
  306. auto kernel_info = std::make_shared<device::KernelInfo>();
  307. anf_node->set_kernel_info(kernel_info);
  308. }
  309. }
  310. CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
  311. MS_EXCEPTION_IF_NULL(cnode);
  312. auto new_cnode = std::make_shared<CNode>(*cnode);
  313. // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map
  314. if (BackendNodeExistInFrontBackendMap(cnode)) {
  315. FrontBackendlMapUpdate(cnode, new_cnode);
  316. }
  317. AnfAlgo::SetGraphId(graph_id_, cnode.get());
  318. return new_cnode;
  319. }
  320. ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
  321. ParameterPtr new_parameter = add_parameter();
  322. MS_EXCEPTION_IF_NULL(new_parameter);
  323. // create kernel_info form new parameter
  324. auto kernel_info = std::make_shared<device::KernelInfo>();
  325. size_t output_tensor_num = 1;
  326. // if use default parameter = nullptr,it remarks create a new parameter from no parameter
  327. if (parameter == nullptr) {
  328. new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
  329. kernel_info->SetFeatureMapFlag(true);
  330. } else {
  331. // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
  332. new_parameter->set_abstract(parameter->abstract());
  333. new_parameter->set_name(parameter->name());
  334. if (AnfAlgo::IsParameterWeight(parameter)) {
  335. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
  336. auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
  337. new_parameter->set_default_param(param_value_new);
  338. kernel_info->SetFeatureMapFlag(false);
  339. } else {
  340. kernel_info->SetFeatureMapFlag(true);
  341. }
  342. // if output is a tuple tensor,now can use for loop to handle tuple tensor
  343. output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
  344. }
  345. new_parameter->set_kernel_info(kernel_info);
  346. // create kernel_build_info for new parameter
  347. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  348. // create init data type,
  349. std::vector<TypeId> init_data_type = {};
  350. for (size_t i = 0; i < output_tensor_num; i++) {
  351. TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i);
  352. init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
  353. }
  354. // set the format of parameter to DEFAULT_FORMAT
  355. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
  356. // set parameter initaial device data type
  357. kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
  358. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
  359. AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
  360. return new_parameter;
  361. }
  362. std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
  363. MS_EXCEPTION_IF_NULL(value_node);
  364. auto node_value = value_node->value();
  365. auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
  366. std::vector<AnfNodePtr> convert_inputs;
  367. if (!node_value->isa<ValueTuple>()) {
  368. MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
  369. }
  370. auto value_tuple = node_value->cast<ValueTuplePtr>();
  371. MS_EXCEPTION_IF_NULL(value_tuple);
  372. if (value_tuple->size() != output_size) {
  373. MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size()
  374. << " is not mathced with the value node's output size" << output_size;
  375. }
  376. for (size_t index = 0; index < value_tuple->value().size(); ++index) {
  377. auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
  378. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)},
  379. {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
  380. AddValueNodeToGraph(new_value_node);
  381. auto kernel_info = std::make_shared<device::KernelInfo>();
  382. new_value_node->set_kernel_info(kernel_info);
  383. kernel_info->SetFeatureMapFlag(false);
  384. // create kernel_build_info for new value node
  385. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  386. // set the format of value_node to DEFAULT_FORMAT
  387. kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  388. // set value node initial device data type = infer data type
  389. kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
  390. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  391. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  392. AddValueNodeToGraph(new_value_node);
  393. convert_inputs.emplace_back(new_value_node);
  394. }
  395. if (!RemoveValueNodeFromGraph(value_node)) {
  396. MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
  397. }
  398. return convert_inputs;
  399. }
  400. ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  401. MS_EXCEPTION_IF_NULL(value_node);
  402. auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>();
  403. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  404. return new_value_node;
  405. }
  406. const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
  407. MS_EXCEPTION_IF_NULL(inputs_);
  408. return *inputs_;
  409. }
  410. void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
  411. MS_EXCEPTION_IF_NULL(front_anf);
  412. MS_EXCEPTION_IF_NULL(backend_anf);
  413. if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
  414. MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
  415. }
  416. if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
  417. MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
  418. }
  419. front_backend_anf_map_[front_anf] = backend_anf;
  420. backend_front_anf_map_[backend_anf] = front_anf;
  421. }
  422. void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
  423. MS_EXCEPTION_IF_NULL(old_backend_anf);
  424. MS_EXCEPTION_IF_NULL(new_backend_anf);
  425. if (old_backend_anf == new_backend_anf) {
  426. MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString();
  427. return;
  428. }
  429. if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
  430. MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
  431. return;
  432. }
  433. if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
  434. MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString();
  435. }
  436. front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
  437. backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
  438. // delete old kernel
  439. (void)backend_front_anf_map_.erase(old_backend_anf);
  440. }
  441. // get kernel by anf
  442. AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
  443. if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {
  444. return nullptr;
  445. }
  446. return front_backend_anf_map_[front_anf];
  447. }
  448. bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
  449. return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
  450. }
  451. ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
  452. if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) {
  453. return nullptr;
  454. }
  455. return tensor_to_value_node_map_[tensor];
  456. }
  457. void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
  458. MS_EXCEPTION_IF_NULL(tensor);
  459. MS_EXCEPTION_IF_NULL(value_node);
  460. tensor_to_value_node_map_[tensor] = value_node;
  461. }
  462. void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
  463. MS_EXCEPTION_IF_NULL(node);
  464. MS_EXCEPTION_IF_NULL(input);
  465. MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num;
  466. auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
  467. // add output depend edge of input
  468. auto output_it = node_output_edges_.find(input);
  469. if (output_it == node_output_edges_.end()) {
  470. node_output_edges_[input] = std::vector<std::pair<AnfNodePtr, size_t>>{output_depend_edge};
  471. } else {
  472. output_it->second.push_back(output_depend_edge);
  473. }
  474. // add input depend edge of output
  475. auto input_depend_edge = std::pair<AnfNodePtr, size_t>(input, depend_edge_num);
  476. auto input_it = node_input_edges_.find(node);
  477. if (input_it == node_input_edges_.end()) {
  478. node_input_edges_[node] = std::vector<std::pair<AnfNodePtr, size_t>>{input_depend_edge};
  479. } else {
  480. input_it->second.push_back(input_depend_edge);
  481. }
  482. // add node input depend num
  483. auto depend_it = node_input_num_.find(node);
  484. if (depend_it == node_input_num_.end()) {
  485. node_input_num_[node] = depend_edge_num;
  486. } else {
  487. depend_it->second += depend_edge_num;
  488. }
  489. }
  490. std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
  491. MS_EXCEPTION_IF_NULL(node);
  492. auto it = node_output_edges_.find(node);
  493. if (it == node_output_edges_.end()) {
  494. MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]";
  495. }
  496. std::vector<AnfNodePtr> output_nodes;
  497. auto trans = [](const std::pair<AnfNodePtr, size_t> &pair) -> AnfNodePtr { return pair.first; };
  498. (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans);
  499. return output_nodes;
  500. }
  501. // Find control_depend real input nodes.
  502. void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
  503. MS_EXCEPTION_IF_NULL(anf_node);
  504. MS_EXCEPTION_IF_NULL(result);
  505. MS_EXCEPTION_IF_NULL(visited);
  506. if (visited->find(anf_node) != visited->end()) {
  507. MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
  508. return;
  509. }
  510. visited->insert(anf_node);
  511. if (AnfAlgo::IsRealKernel(anf_node)) {
  512. result->emplace_back(anf_node);
  513. return;
  514. }
  515. if (!anf_node->isa<CNode>()) {
  516. return;
  517. }
  518. auto cnode = anf_node->cast<CNodePtr>();
  519. MS_EXCEPTION_IF_NULL(cnode);
  520. if (cnode->inputs().empty()) {
  521. MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
  522. }
  523. auto input0 = cnode->input(0);
  524. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  525. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  526. GetAllFatherRealNode(cnode->input(i), result, visited);
  527. }
  528. } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
  529. if (cnode->inputs().size() != kTupleGetItemInputSize) {
  530. MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
  531. }
  532. GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
  533. } else if (IsPrimitive(input0, prim::kPrimDepend)) {
  534. if (cnode->inputs().size() != kDependInputSize) {
  535. MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
  536. }
  537. GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
  538. GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
  539. }
  540. }
  541. // update the depend relations of control depend
  542. void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
  543. for (const auto &node : depends) {
  544. MS_EXCEPTION_IF_NULL(node);
  545. if (!node->isa<CNode>()) {
  546. return;
  547. }
  548. auto cnode = node->cast<CNodePtr>();
  549. MS_EXCEPTION_IF_NULL(cnode);
  550. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
  551. MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
  552. }
  553. auto prior_node = cnode->input(kControlDependPriorIndex);
  554. auto depend_node = cnode->input(kControlDependBehindIndex);
  555. MS_EXCEPTION_IF_NULL(prior_node);
  556. MS_EXCEPTION_IF_NULL(depend_node);
  557. std::vector<AnfNodePtr> prior_nodes = {prior_node};
  558. std::vector<AnfNodePtr> depend_nodes = {depend_node};
  559. int depend_mode = 0;
  560. if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
  561. depend_mode = AnfAlgo::GetNodeAttr<int>(cnode, kControlDependMode);
  562. }
  563. MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
  564. << "], depend_mode :" << depend_mode << ".";
  565. if (prior_node->isa<Parameter>() && depend_mode == 1) {
  566. prior_nodes = GetOutputNodes(prior_node);
  567. }
  568. if (depend_node->isa<Parameter>()) {
  569. depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
  570. }
  571. std::vector<AnfNodePtr> real_prior_nodes;
  572. std::set<AnfNodePtr> prior_visited;
  573. for (const auto &tmp : prior_nodes) {
  574. GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
  575. }
  576. std::vector<AnfNodePtr> real_depend_nodes;
  577. std::set<AnfNodePtr> depend_visited;
  578. for (const auto &tmp : depend_nodes) {
  579. GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
  580. }
  581. for (auto &first_node : real_prior_nodes) {
  582. if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
  583. continue;
  584. }
  585. for (auto &second_node : real_depend_nodes) {
  586. if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
  587. continue;
  588. }
  589. MS_EXCEPTION_IF_NULL(first_node);
  590. MS_EXCEPTION_IF_NULL(second_node);
  591. MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
  592. AddDependEdge(second_node, first_node, 1);
  593. }
  594. }
  595. }
  596. }
// Handles a ControlDepend node encountered while rebuilding the edge lists.
// Returns true iff `node` is a ControlDepend CNode. In that case the node is
// marked visited (but deliberately NOT pushed onto `que`), every one of its
// inputs gets a zero-weight depend edge from it, and its prior/behind inputs
// are queued for traversal if not yet visited.
bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
                                          std::unordered_set<AnfNodePtr> *visited_nodes) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(que);
  MS_EXCEPTION_IF_NULL(visited_nodes);
  // Only CNodes can be ControlDepend; anything else is not handled here.
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
    return false;
  }
  // set the control depend visited but don't push it into the que
  if (visited_nodes->find(node) != visited_nodes->end()) {
    // Already processed earlier in the walk; still report it as handled.
    return true;
  }
  (void)visited_nodes->insert(cnode);
  // add a 0 depend num to keep the link relations to prepare for finding zero output nodes
  auto prior_node = cnode->input(kControlDependPriorIndex);
  auto depend_node = cnode->input(kControlDependBehindIndex);
  for (const auto &input : cnode->inputs()) {
    AddDependEdge(node, input, 0);
  }
  // Queue the two operands (depend first, then prior) so the BFS visits them.
  PushNoVisitedNode(depend_node, que, visited_nodes);
  PushNoVisitedNode(prior_node, que, visited_nodes);
  return true;
}
// Rebuilds the graph's dependency bookkeeping (node_output_edges_,
// node_input_num_, node_input_edges_) by a BFS starting from the return node.
// Parameters and ValueNodes reached during the walk — nodes with no inputs —
// are pushed into `seed_nodes`. ControlDepend nodes get special treatment:
// their inputs receive zero-weight edges so they do not count as data
// dependencies, and the collected ControlDepend nodes are turned into real
// ordering edges at the end via UpdateControlDependRelations.
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
  MS_EXCEPTION_IF_NULL(seed_nodes);
  // Start from a clean slate; the maps are recomputed in full below.
  node_output_edges_.clear();
  node_input_num_.clear();
  node_input_edges_.clear();
  std::vector<AnfNodePtr> control_depends;
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> que;
  que.push(get_return());
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    MS_EXCEPTION_IF_NULL(node);
    // Leaf nodes (no inputs) seed downstream topological processing.
    if (node->isa<Parameter>() || node->isa<ValueNode>()) {
      seed_nodes->push(node);
      continue;
    }
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // handle data links
    for (const auto &input : cnode->inputs()) {
      size_t depend_edge_num = 1;
      // handle control depend,all inputs of control depend has no depend edge
      if (HandleControlDependNode(input, &que, &visited_nodes)) {
        control_depends.push_back(input);
        depend_edge_num = 0;
      }
      PushNoVisitedNode(input, &que, &visited_nodes);
      AddDependEdge(node, input, depend_edge_num);
    }
  }
  UpdateControlDependRelations(control_depends);
}
  661. void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); }
  662. bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }
  663. AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
  664. if (!IsInRefOutputMap(out_pair)) {
  665. MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap";
  666. }
  667. return ref_out_in_map_.at(out_pair);
  668. }
  669. void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
  670. if (IsInRefOutputMap(final_pair)) {
  671. MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap";
  672. }
  673. (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
  674. }
  675. bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
  676. if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
  677. (void)graph_value_nodes_.erase(value_node);
  678. return true;
  679. }
  680. return false;
  681. }
// Replaces `old_anf_node` with `new_anf_node` everywhere it is consumed:
// in the data inputs of its user CNodes, in the graph input list, in the
// front/backend node map, in node_output_edges_, and in the real-input
// records used by child graphs. Control edges (weight 0) are left alone.
void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
  MS_EXCEPTION_IF_NULL(inputs_);
  auto it = node_output_edges_.find(old_anf_node);
  if (it != node_output_edges_.end()) {
    const auto &outputs = it->second;
    for (auto &output_node : outputs) {
      MS_EXCEPTION_IF_NULL(output_node.first);
      auto output_cnode = output_node.first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(output_cnode);
      auto &output_node_inputs = output_cnode->inputs();
      // don't replace node if it is a control edge => output_node.second == 0
      if (output_node.second == 0) {
        continue;
      }
      // Rewire every data input (index 0 is the primitive) of the consumer
      // that still points at the old node.
      for (size_t i = 1; i < output_node_inputs.size(); i++) {
        if (output_node_inputs[i] == old_anf_node.get()) {
          output_cnode->set_input(i, new_anf_node);
        }
      }
      // update graph inputs
      // NOTE(review): this scan runs once per consumer even though the
      // replacement is idempotent; hoisting it out of the outer loop looks
      // safe but should be confirmed before changing.
      for (size_t i = 0; i < inputs_->size(); i++) {
        if ((*inputs_)[i] == old_anf_node.get()) {
          MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
                       << ",new graph input:" << new_anf_node->DebugString();
          (*inputs_)[i] = new_anf_node.get();
          break;
        }
      }
    }
    // update front to backend map
    FrontBackendlMapUpdate(old_anf_node, new_anf_node);
    // update output depend relations
    // NOTE(review): operator[] may insert (and, for a hash map, rehash) before
    // `it->second` is read; pre-C++17 the evaluation order here is unspecified
    // — verify the container type / language standard if this is ever touched.
    node_output_edges_[new_anf_node.get()] = it->second;
    (void)node_output_edges_.erase(old_anf_node);
  }
  // update graph inputs in child graph
  auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
                                     [&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
                                       return n.first == old_anf_node.get();
                                     });
  if (it_real_inputs != real_inputs_.end()) {
    // erase old parameter in map
    auto old_args = it_real_inputs->second;
    real_inputs_.erase(it_real_inputs);
    // insert new parameter to map
    auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(),
                             [&new_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
                               return n.first == new_anf_node.get();
                             });
    if (iter != real_inputs_.end()) {
      // The new node already has real-input records; the old node's args win.
      MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
      iter->second = old_args;
    } else {
      real_inputs_.emplace_back(new_anf_node, old_args);
    }
  }
}
  739. void KernelGraph::UpdateExecuteKernelStreamLabel() {
  740. for (auto &kernel : execution_order_) {
  741. AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
  742. }
  743. }
  744. std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
  745. std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
  746. if (IsLeafGraph()) {
  747. leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
  748. } else {
  749. for (const auto &child_graph : child_graph_order_) {
  750. MS_EXCEPTION_IF_NULL(child_graph);
  751. auto child_leaf_graph_order = child_graph->GetLeafGraphOrder();
  752. std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
  753. }
  754. }
  755. return leaf_graph_order;
  756. }
  757. bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
  758. std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
  759. std::vector<CNodePtr> result;
  760. for (const auto &anf : execution_order_) {
  761. if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
  762. result.push_back(anf->cast<CNodePtr>());
  763. }
  764. }
  765. return result;
  766. }
  767. void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
  768. MS_EXCEPTION_IF_NULL(parameter);
  769. MS_EXCEPTION_IF_NULL(arg);
  770. MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
  771. MS_EXCEPTION_IF_NULL(parameter);
  772. MS_EXCEPTION_IF_NULL(arg);
  773. auto iter = std::find_if(
  774. real_inputs_.begin(), real_inputs_.end(),
  775. [&parameter](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { return n.first == parameter; });
  776. if (iter != real_inputs_.end()) {
  777. auto &args = iter->second;
  778. args.push_back(arg);
  779. } else {
  780. real_inputs_.emplace_back(parameter, std::vector<AnfNodePtr>(1, arg));
  781. }
  782. }
  783. void KernelGraph::UpdateCallRealInput() {
  784. MS_LOG(INFO) << "Update graph id: " << graph_id_;
  785. std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
  786. for (auto &it : real_inputs_) {
  787. auto parameter = it.first;
  788. MS_EXCEPTION_IF_NULL(parameter);
  789. auto real_inputs = it.second;
  790. std::vector<AnfNodePtr> new_real_inputs;
  791. for (auto &real_input : real_inputs) {
  792. // if real input is a call node ,find the child graph output act as the new real input
  793. auto tmp_real_input = GetCallRealOutputs(real_input);
  794. std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
  795. }
  796. real_inputs_map.emplace_back(parameter, new_real_inputs);
  797. }
  798. real_inputs_ = real_inputs_map;
  799. }
  800. void KernelGraph::PrintGraphExecuteOrder() const {
  801. MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
  802. for (size_t i = 0; i < execution_order_.size(); i++) {
  803. CNodePtr cur_cnode_ptr = execution_order_[i];
  804. MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
  805. std::string event_str;
  806. std::string label_str;
  807. if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
  808. event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
  809. }
  810. if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
  811. label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
  812. }
  813. if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
  814. auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
  815. label_str = ", label_id[";
  816. for (size_t j = 0; j < label_list.size(); ++j) {
  817. label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
  818. }
  819. }
  820. MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
  821. << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
  822. << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
  823. << event_str << label_str;
  824. }
  825. }
  826. std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
// On destruction, asks the kernel runtime manager singleton to drop whatever
// resources it registered under this graph's id.
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
  828. } // namespace session
  829. } // namespace mindspore