You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_graph.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/kernel_graph.h"
  17. #include <algorithm>
  18. #include <queue>
  19. #include <stack>
  20. #include <unordered_set>
  21. #include "common/utils.h"
  22. #include "operator/ops.h"
  23. #include "session/anf_runtime_algorithm.h"
  24. #include "device/kernel_info.h"
  25. #include "kernel/kernel_build_info.h"
  26. namespace mindspore {
  27. namespace session {
  28. namespace {
  29. void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  30. std::unordered_set<AnfNodePtr> *visited_nodes) {
  31. MS_EXCEPTION_IF_NULL(que);
  32. MS_EXCEPTION_IF_NULL(visited_nodes);
  33. if (visited_nodes->find(node) == visited_nodes->end()) {
  34. que->push(node);
  35. (void)visited_nodes->insert(node);
  36. MS_LOG(DEBUG) << "Push que:" << node->DebugString();
  37. }
  38. }
  39. } // namespace
  40. std::vector<AnfNodePtr> KernelGraph::outputs() const {
  41. MS_EXCEPTION_IF_NULL(output());
  42. if (IsPrimitiveCNode(output(), prim::kPrimMakeTuple)) {
  43. auto make_tuple = output()->cast<CNodePtr>();
  44. MS_EXCEPTION_IF_NULL(make_tuple);
  45. auto &inputs = make_tuple->inputs();
  46. return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end());
  47. }
  48. return std::vector<AnfNodePtr>();
  49. }
  50. void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
  51. std::unordered_set<AnfNodePtr> *visited_nodes) {
  52. MS_EXCEPTION_IF_NULL(visit_queue);
  53. MS_EXCEPTION_IF_NULL(visited_nodes);
  54. auto it = node_output_edges_.find(node);
  55. if (it == node_output_edges_.end()) {
  56. // value node and parameter has no input,no need to print log
  57. if (node->isa<CNode>()) {
  58. MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
  59. }
  60. return;
  61. }
  62. // visit all reduce node first, then other nodes
  63. std::vector<AnfNodePtr> active_nodes;
  64. for (const auto &output_edge : it->second) {
  65. auto next_node = output_edge.first;
  66. if (node_input_num_.find(next_node) == node_input_num_.end()) {
  67. MS_EXCEPTION_IF_NULL(next_node);
  68. MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
  69. }
  70. MS_EXCEPTION_IF_NULL(next_node);
  71. MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
  72. << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
  73. if (node_input_num_[next_node] < output_edge.second) {
  74. MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
  75. << ",depend edge:" << output_edge.second;
  76. }
  77. node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
  78. // allreduce first
  79. if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
  80. (void)visited_nodes->insert(next_node);
  81. if (AnfAlgo::IsCommunicationOp(next_node)) {
  82. MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
  83. visit_queue->push(next_node);
  84. } else {
  85. active_nodes.emplace_back(next_node);
  86. }
  87. }
  88. }
  89. for (auto &node : active_nodes) {
  90. MS_LOG(DEBUG) << "visit node:" << node->DebugString();
  91. visit_queue->push(node);
  92. }
  93. }
// Computes execution_order_ by a topological traversal over the depend edges built by
// UpdateNodeEdgeList. Communication (all-reduce) ops are prioritized: after one is scheduled,
// the descendants of the previously scheduled communication node are visited first, chaining
// communication ops as early as possible. Ends with CheckLoop() to detect cycles.
void KernelGraph::SetExecOrderByDefault() {
  std::queue<AnfNodePtr> seed_nodes;  // nodes with no inputs (parameters / value nodes)
  UpdateNodeEdgeList(&seed_nodes);
  execution_order_.clear();
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> zero_input_nodes;           // ordinary ready nodes
  AnfNodePtr last_communication_node = nullptr;      // most recently scheduled communication op
  std::queue<AnfNodePtr> communication_descendants;  // ready nodes released by it
  while (!seed_nodes.empty() || last_communication_node != nullptr) {
    // seed nodes first, then visit last all reduce node descendant
    if (seed_nodes.empty()) {
      VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
      last_communication_node = nullptr;
    } else {
      zero_input_nodes.push(seed_nodes.front());
      seed_nodes.pop();
    }
    // all reduce node descendant first, then common queue
    while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
      AnfNodePtr node = nullptr;
      bool is_communication_descendant = false;
      if (communication_descendants.empty()) {
        node = zero_input_nodes.front();
        zero_input_nodes.pop();
      } else {
        node = communication_descendants.front();
        communication_descendants.pop();
        is_communication_descendant = true;
      }
      // add execute node: only real kernels enter the execution order
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
        execution_order_.push_back(node->cast<CNodePtr>());
      }
      // for all reduce node, visit last all reduce node descendant
      if (AnfAlgo::IsCommunicationOp(node)) {
        if (last_communication_node != nullptr) {
          VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
        }
        last_communication_node = node;
      } else if (is_communication_descendant) {
        // descendants of a communication node keep feeding the priority queue
        VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
      } else {
        VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
      }
    }
  }
  // any node whose pending-input count is still non-zero indicates a loop
  CheckLoop();
}
  143. void KernelGraph::CheckLoop() {
  144. std::map<AnfNodePtr, size_t> none_zero_nodes;
  145. if (node_input_edges_.size() != node_input_num_.size()) {
  146. MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
  147. << "not equal to node_input_num_ size:" << node_input_num_.size();
  148. }
  149. for (auto &it : node_input_num_) {
  150. MS_EXCEPTION_IF_NULL(it.first);
  151. string str;
  152. auto node_input_it = node_input_edges_.find(it.first);
  153. if (node_input_it == node_input_edges_.end()) {
  154. MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
  155. }
  156. for (const auto &input_edge : node_input_edges_[it.first]) {
  157. MS_EXCEPTION_IF_NULL(input_edge.first);
  158. str = str.append(input_edge.first->DebugString()).append("|");
  159. }
  160. if (it.second != 0) {
  161. MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
  162. none_zero_nodes[it.first] = it.second;
  163. }
  164. }
  165. // if don't consider control depend and loop exit,a exception will be throw
  166. if (!none_zero_nodes.empty()) {
  167. MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
  168. }
  169. }
  170. CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  171. auto cnode = FuncGraph::NewCNode(inputs);
  172. MS_EXCEPTION_IF_NULL(cnode);
  173. cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
  174. // create kernel_info from new parameter
  175. auto kernel_info = std::make_shared<device::KernelInfo>();
  176. // if the node only has the primitive(such as getNext) or the node's input has a feature map input
  177. // then the node's output is a feature map output
  178. if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
  179. [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
  180. kernel_info->SetFeatureMapFlag(true);
  181. }
  182. cnode->set_kernel_info(kernel_info);
  183. AnfAlgo::SetGraphId(graph_id_, cnode.get());
  184. return cnode;
  185. }
  186. CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
  187. MS_EXCEPTION_IF_NULL(cnode);
  188. auto new_cnode = std::make_shared<CNode>(*cnode);
  189. // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map
  190. if (BackendNodeExistInFrontBackendMap(cnode)) {
  191. FrontBackendlMapUpdate(cnode, new_cnode);
  192. }
  193. AnfAlgo::SetGraphId(graph_id_, cnode.get());
  194. return new_cnode;
  195. }
// Creates a new parameter in this graph. When `parameter` is non-null its abstract, name and
// (for weights) default value are copied; otherwise an empty placeholder parameter is created.
// Attaches kernel info and a default-format kernel build info, then stamps the graph id.
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
  ParameterPtr new_parameter = add_parameter();
  MS_EXCEPTION_IF_NULL(new_parameter);
  // create kernel_info for the new parameter
  auto kernel_info = std::make_shared<device::KernelInfo>();
  size_t output_tensor_num = 1;
  // if using the default parameter = nullptr, it means create a new parameter from no parameter
  if (parameter == nullptr) {
    new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
    kernel_info->SetFeatureMapFlag(true);
  } else {
    // otherwise create the new parameter from an old parameter
    new_parameter->set_abstract(parameter->abstract());
    new_parameter->set_name(parameter->name());
    if (AnfAlgo::IsParameterWeight(parameter)) {
      // weights keep their default value and are not feature maps
      new_parameter->set_default_param(parameter->default_param());
      kernel_info->SetFeatureMapFlag(false);
    } else {
      kernel_info->SetFeatureMapFlag(true);
    }
    // if output is a tuple tensor, now can use for loop to handle tuple tensor
    output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
  }
  new_parameter->set_kernel_info(kernel_info);
  // create kernel_build_info for new parameter
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  // create init data type: weights start as kTypeUnknown, others use the inferred type
  std::vector<TypeId> init_data_type = {};
  for (size_t i = 0; i < output_tensor_num; i++) {
    TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i);
    init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
  }
  // set the format of parameter to DEFAULT_FORMAT
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
  // set parameter initial device data type
  kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
  AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
  return new_parameter;
}
  236. std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
  237. MS_EXCEPTION_IF_NULL(value_node);
  238. auto node_value = value_node->value();
  239. auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
  240. std::vector<AnfNodePtr> convert_inputs;
  241. if (!node_value->isa<ValueTuple>()) {
  242. MS_LOG(EXCEPTION) << "multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
  243. }
  244. auto value_tuple = node_value->cast<ValueTuplePtr>();
  245. MS_EXCEPTION_IF_NULL(value_tuple);
  246. if (value_tuple->size() != output_size) {
  247. MS_LOG(EXCEPTION) << "value tuple size" << value_tuple->size()
  248. << " is not mathced with the value node's output size" << output_size;
  249. }
  250. for (size_t index = 0; index < value_tuple->value().size(); ++index) {
  251. auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
  252. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)},
  253. {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
  254. AddValueNodeToGraph(new_value_node);
  255. auto kernel_info = std::make_shared<device::KernelInfo>();
  256. new_value_node->set_kernel_info(kernel_info);
  257. kernel_info->SetFeatureMapFlag(false);
  258. // create kernel_build_info for new value node
  259. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  260. // set the format of value_node to DEFAULT_FORMAT
  261. kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  262. // set value node initial device data type = infer data type
  263. kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
  264. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  265. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  266. AddValueNodeToGraph(new_value_node);
  267. convert_inputs.emplace_back(new_value_node);
  268. }
  269. if (!RemoveValueNodeFromGraph(value_node)) {
  270. MS_LOG(WARNING) << "failed to remove the value_node " << value_node->DebugString();
  271. }
  272. return convert_inputs;
  273. }
  274. ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  275. MS_EXCEPTION_IF_NULL(value_node);
  276. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  277. new_value_node->set_abstract(value_node->abstract());
  278. // create kernel_info fo new value node
  279. auto kernel_info = std::make_shared<device::KernelInfo>();
  280. kernel_info->SetFeatureMapFlag(false);
  281. new_value_node->set_kernel_info(kernel_info);
  282. // create kernel_build_info for new value node
  283. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  284. // set the format of value_node to DEFAULT_FORMAT
  285. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  286. // set value node initial device data type = infer data type
  287. std::vector<TypeId> types;
  288. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
  289. types.push_back(kTypeUnknown);
  290. }
  291. kernel_build_info_builder->SetOutputsDeviceType(types);
  292. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  293. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  294. return new_value_node;
  295. }
  296. const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
  297. MS_EXCEPTION_IF_NULL(inputs_);
  298. return *inputs_;
  299. }
  300. void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
  301. MS_EXCEPTION_IF_NULL(front_anf);
  302. MS_EXCEPTION_IF_NULL(backend_anf);
  303. if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
  304. MS_LOG(EXCEPTION) << "anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
  305. }
  306. if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
  307. MS_LOG(EXCEPTION) << "kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
  308. }
  309. front_backend_anf_map_[front_anf] = backend_anf;
  310. backend_front_anf_map_[backend_anf] = front_anf;
  311. }
  312. void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
  313. MS_EXCEPTION_IF_NULL(old_backend_anf);
  314. MS_EXCEPTION_IF_NULL(new_backend_anf);
  315. if (old_backend_anf.get() == new_backend_anf.get()) {
  316. MS_LOG(EXCEPTION) << "old can't be same with new";
  317. }
  318. if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
  319. MS_LOG(EXCEPTION) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
  320. }
  321. if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
  322. MS_LOG(EXCEPTION) << "anf is not exist in the mape ,old " << old_backend_anf->DebugString();
  323. }
  324. front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
  325. backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
  326. // delete old kernel
  327. (void)backend_front_anf_map_.erase(old_backend_anf);
  328. }
  329. // get kernel by anf
  330. AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
  331. if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {
  332. return nullptr;
  333. }
  334. return front_backend_anf_map_[front_anf];
  335. }
  336. bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
  337. return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
  338. }
  339. ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
  340. if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) {
  341. return nullptr;
  342. }
  343. return tensor_to_value_node_map_[tensor];
  344. }
  345. void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
  346. MS_EXCEPTION_IF_NULL(tensor);
  347. MS_EXCEPTION_IF_NULL(value_node);
  348. tensor_to_value_node_map_[tensor] = value_node;
  349. }
  350. void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
  351. MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num;
  352. auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
  353. // add output depend edge of input
  354. auto output_it = node_output_edges_.find(input);
  355. if (output_it == node_output_edges_.end()) {
  356. node_output_edges_[input] = std::vector<std::pair<AnfNodePtr, size_t>>{output_depend_edge};
  357. } else {
  358. output_it->second.push_back(output_depend_edge);
  359. }
  360. // add input depend edge of output
  361. auto input_depend_edge = std::pair<AnfNodePtr, size_t>(input, depend_edge_num);
  362. auto input_it = node_input_edges_.find(node);
  363. if (input_it == node_input_edges_.end()) {
  364. node_input_edges_[node] = std::vector<std::pair<AnfNodePtr, size_t>>{input_depend_edge};
  365. } else {
  366. input_it->second.push_back(input_depend_edge);
  367. }
  368. // add node input depend num
  369. auto depend_it = node_input_num_.find(node);
  370. if (depend_it == node_input_num_.end()) {
  371. node_input_num_[node] = depend_edge_num;
  372. } else {
  373. depend_it->second += depend_edge_num;
  374. }
  375. }
  376. std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
  377. MS_EXCEPTION_IF_NULL(node);
  378. auto it = node_output_edges_.find(node);
  379. if (it == node_output_edges_.end()) {
  380. MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]";
  381. }
  382. std::vector<AnfNodePtr> output_nodes;
  383. auto trans = [](const std::pair<AnfNodePtr, size_t> &pair) -> AnfNodePtr { return pair.first; };
  384. (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans);
  385. return output_nodes;
  386. }
  387. // update the depend relations of control depend
  388. void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
  389. for (const auto &node : depends) {
  390. MS_EXCEPTION_IF_NULL(node);
  391. if (!node->isa<CNode>()) {
  392. return;
  393. }
  394. auto cnode = node->cast<CNodePtr>();
  395. MS_EXCEPTION_IF_NULL(cnode);
  396. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
  397. MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
  398. }
  399. auto prior_node = cnode->input(kControlDependPriorIndex);
  400. auto depend_node = cnode->input(kControlDependBehindIndex);
  401. MS_EXCEPTION_IF_NULL(prior_node);
  402. MS_EXCEPTION_IF_NULL(depend_node);
  403. std::vector<AnfNodePtr> prior_nodes = {prior_node};
  404. std::vector<AnfNodePtr> depend_nodes = {depend_node};
  405. MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString();
  406. if (prior_node->isa<Parameter>()) {
  407. prior_nodes = GetOutputNodes(prior_node);
  408. }
  409. if (depend_node->isa<Parameter>()) {
  410. depend_nodes = GetOutputNodes(depend_node);
  411. }
  412. for (auto &first_node : prior_nodes) {
  413. for (auto &second_node : depend_nodes) {
  414. MS_EXCEPTION_IF_NULL(first_node);
  415. MS_EXCEPTION_IF_NULL(second_node);
  416. MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
  417. AddDependEdge(second_node, first_node, 1);
  418. }
  419. }
  420. }
  421. }
// Handles a ControlDepend cnode during edge-list construction: marks it visited (without
// queueing it), links all of its inputs with zero-weight depend edges so connectivity is kept
// without affecting input counts, and queues its prior/behind inputs for traversal.
// Returns true iff `node` was a ControlDepend cnode; false means the caller should treat it
// as an ordinary node.
bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
                                          std::unordered_set<AnfNodePtr> *visited_nodes) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
    return false;
  }
  // set the control depend visited but don't push it into the que;
  // seeing the same ControlDepend twice indicates a traversal bug, hence fatal
  if (visited_nodes->find(node) != visited_nodes->end()) {
    MS_LOG(EXCEPTION) << "control depend[" << node->DebugString() << "] has been handled before";
  }
  (void)visited_nodes->insert(cnode);
  // add a 0 depend num to keep the link relations to prepare for finding zero output nodes
  auto prior_node = cnode->input(kControlDependPriorIndex);
  auto depend_node = cnode->input(kControlDependBehindIndex);
  for (const auto &input : cnode->inputs()) {
    AddDependEdge(node, input, 0);
  }
  PushNoVisitedNode(depend_node, que, visited_nodes);
  PushNoVisitedNode(prior_node, que, visited_nodes);
  return true;
}
// Rebuilds node_input_edges_ / node_output_edges_ / node_input_num_ by a breadth-first walk
// from the graph's return node. Input-less nodes (parameters and value nodes) are collected
// into `seed_nodes` to seed the topological traversal; ControlDepend nodes are recorded and
// their extra depend relations applied at the end.
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
  node_output_edges_.clear();
  node_input_num_.clear();
  node_input_edges_.clear();
  std::vector<AnfNodePtr> control_depends;
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> que;
  que.push(get_return());
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    MS_EXCEPTION_IF_NULL(node);
    // leaves (parameters / value nodes) have no inputs and become traversal seeds
    if (node->isa<Parameter>() || node->isa<ValueNode>()) {
      seed_nodes->push(node);
      continue;
    }
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // handle data links
    for (const auto &input : cnode->inputs()) {
      size_t depend_edge_num = 1;
      // handle control depend: all inputs of a control depend carry a zero-weight edge
      if (HandleControlDependNode(input, &que, &visited_nodes)) {
        control_depends.push_back(input);
        depend_edge_num = 0;
      }
      PushNoVisitedNode(input, &que, &visited_nodes);
      AddDependEdge(node, input, depend_edge_num);
    }
  }
  UpdateControlDependRelations(control_depends);
}
  483. void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); }
  484. bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }
  485. AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
  486. if (!IsInRefOutputMap(out_pair)) {
  487. MS_LOG(EXCEPTION) << "out_pair is not in RefOutputMap";
  488. }
  489. return ref_out_in_map_.at(out_pair);
  490. }
  491. void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
  492. if (IsInRefOutputMap(final_pair)) {
  493. MS_LOG(EXCEPTION) << "out_pair is already in RefOutputMap";
  494. }
  495. (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
  496. }
  497. bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
  498. if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
  499. (void)graph_value_nodes_.erase(value_node);
  500. return true;
  501. }
  502. return false;
  503. }
  504. } // namespace session
  505. } // namespace mindspore