You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel_graph.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/kernel_graph.h"
  17. #include <algorithm>
  18. #include <map>
  19. #include <memory>
  20. #include <queue>
  21. #include <stack>
  22. #include <string>
  23. #include <unordered_set>
  24. #include <utility>
  25. #include <vector>
  26. #include "common/utils.h"
  27. #include "operator/ops.h"
  28. #include "session/anf_runtime_algorithm.h"
  29. #include "device/kernel_info.h"
  30. #include "kernel/kernel_build_info.h"
  26. namespace mindspore {
  27. namespace session {
  28. namespace {
  29. void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  30. std::unordered_set<AnfNodePtr> *visited_nodes) {
  31. MS_EXCEPTION_IF_NULL(que);
  32. MS_EXCEPTION_IF_NULL(visited_nodes);
  33. if (visited_nodes->find(node) == visited_nodes->end()) {
  34. que->push(node);
  35. (void)visited_nodes->insert(node);
  36. MS_LOG(DEBUG) << "Push que:" << node->DebugString();
  37. }
  38. }
  39. } // namespace
  40. std::vector<AnfNodePtr> KernelGraph::outputs() const {
  41. MS_EXCEPTION_IF_NULL(output());
  42. if (IsPrimitiveCNode(output(), prim::kPrimMakeTuple)) {
  43. auto make_tuple = output()->cast<CNodePtr>();
  44. MS_EXCEPTION_IF_NULL(make_tuple);
  45. auto &inputs = make_tuple->inputs();
  46. return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end());
  47. }
  48. return std::vector<AnfNodePtr>();
  49. }
// Compute execution_order_ by a topological (BFS) walk over the dependency
// edges built by UpdateNodeEdgeList. AllReduce nodes are kept in a serialized
// chain: descendants of the previous AllReduce are only released when the
// next AllReduce is reached, which pins communication ops to a deterministic
// relative order.
void KernelGraph::SetExecOrderByDefault() {
  std::stack<AnfNodePtr> seed_nodes;
  // rebuild node_input_edges_/node_output_edges_/node_input_num_ and collect
  // zero-input leaves (parameters and value nodes) as traversal seeds
  UpdateNodeEdgeList(&seed_nodes);
  execution_order_.clear();
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> zero_input_nodes;
  // Decrease the pending-input counter of every consumer of `node`; consumers
  // whose counter reaches zero are enqueued — AllReduce consumers immediately,
  // all others afterwards, so AllReduce ops are visited first.
  auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) {
    auto it = node_output_edges_.find(node);
    if (it == node_output_edges_.end()) {
      // value node and parameter has no input,no need to print log
      if (node->isa<CNode>()) {
        MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
      }
      return;
    }
    // visit all reduce node first, then other nodes
    std::vector<AnfNodePtr> active_nodes;
    for (const auto &output_edge : it->second) {
      auto next_node = output_edge.first;
      if (node_input_num_.find(next_node) == node_input_num_.end()) {
        MS_EXCEPTION_IF_NULL(next_node);
        MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
      }
      MS_EXCEPTION_IF_NULL(next_node);
      MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
                    << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
      if (node_input_num_[next_node] < output_edge.second) {
        // counter underflow would mean the edge bookkeeping is corrupt
        MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num"
                          << node_input_num_[next_node] << ",depend edge:" << output_edge.second;
      }
      node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
      // allreduce first
      if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) {
        (void)visited_nodes.insert(next_node);
        if (AnfAlgo::IsAllReduceOp(next_node)) {
          MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
          visit_queue->push(next_node);
        } else {
          active_nodes.emplace_back(next_node);
        }
      }
    }
    // non-AllReduce consumers are enqueued after the AllReduce ones above
    for (auto &node : active_nodes) {
      MS_LOG(DEBUG) << "visit node:" << node->DebugString();
      visit_queue->push(node);
    }
  };
  AnfNodePtr last_allreduce_node = nullptr;
  std::queue<AnfNodePtr> allreduce_descendants;
  // outer loop: drain seeds; a pending AllReduce keeps the loop alive so its
  // descendants can still be released after all seeds are consumed
  while (!seed_nodes.empty() || last_allreduce_node != nullptr) {
    // seed nodes first, then visit last all reduce node descendant
    if (seed_nodes.empty()) {
      visit_node_descendant(last_allreduce_node, &allreduce_descendants);
      last_allreduce_node = nullptr;
    } else {
      zero_input_nodes.push(seed_nodes.top());
      seed_nodes.pop();
    }
    // all reduce node descendant first, then common queue
    while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) {
      AnfNodePtr node = nullptr;
      bool is_allreduce_descendant = false;
      if (allreduce_descendants.empty()) {
        node = zero_input_nodes.front();
        zero_input_nodes.pop();
      } else {
        node = allreduce_descendants.front();
        allreduce_descendants.pop();
        is_allreduce_descendant = true;
      }
      // add execute node
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
        execution_order_.push_back(node->cast<CNodePtr>());
      }
      // for all reduce node, visit last all reduce node descendant
      if (AnfAlgo::IsAllReduceOp(node)) {
        if (last_allreduce_node != nullptr) {
          visit_node_descendant(last_allreduce_node, &allreduce_descendants);
        }
        last_allreduce_node = node;
      } else if (is_allreduce_descendant) {
        visit_node_descendant(node, &allreduce_descendants);
      } else {
        visit_node_descendant(node, &zero_input_nodes);
      }
    }
  }
  // every pending-input counter must now be zero; otherwise the graph cycles
  CheckLoop();
}
  140. void KernelGraph::CheckLoop() {
  141. std::map<AnfNodePtr, size_t> none_zero_nodes;
  142. if (node_input_edges_.size() != node_input_num_.size()) {
  143. MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
  144. << "not equal to node_input_num_ size:" << node_input_num_.size();
  145. }
  146. for (auto &it : node_input_num_) {
  147. MS_EXCEPTION_IF_NULL(it.first);
  148. string str;
  149. auto node_input_it = node_input_edges_.find(it.first);
  150. if (node_input_it == node_input_edges_.end()) {
  151. MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
  152. }
  153. for (const auto &input_edge : node_input_edges_[it.first]) {
  154. MS_EXCEPTION_IF_NULL(input_edge.first);
  155. str = str.append(input_edge.first->DebugString()).append("|");
  156. }
  157. if (it.second != 0) {
  158. MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
  159. none_zero_nodes[it.first] = it.second;
  160. }
  161. }
  162. // if don't consider control depend and loop exit,a exception will be throw
  163. if (!none_zero_nodes.empty()) {
  164. MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
  165. }
  166. }
  167. CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  168. auto cnode = FuncGraph::NewCNode(inputs);
  169. MS_EXCEPTION_IF_NULL(cnode);
  170. cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
  171. // create kernel_info from new parameter
  172. auto kernel_info = std::make_shared<device::KernelInfo>();
  173. // if the node only has the primitive(such as getNext) or the node's input has a feature map input
  174. // then the node's output is a feature map output
  175. if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
  176. [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
  177. kernel_info->SetFeatureMapFlag(true);
  178. }
  179. cnode->set_kernel_info(kernel_info);
  180. AnfAlgo::SetGraphId(graph_id_, cnode.get());
  181. return cnode;
  182. }
  183. CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
  184. MS_EXCEPTION_IF_NULL(cnode);
  185. auto new_cnode = std::make_shared<CNode>(*cnode);
  186. // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map
  187. if (BackendNodeExistInFrontBackendMap(cnode)) {
  188. FrontBackendlMapUpdate(cnode, new_cnode);
  189. }
  190. AnfAlgo::SetGraphId(graph_id_, cnode.get());
  191. return new_cnode;
  192. }
// Create a new parameter in this graph. When `parameter` is non-null, the new
// parameter copies its abstract, name and (for weights) default value;
// otherwise an empty AbstractNone parameter is produced. A KernelInfo and a
// default-format KernelBuildInfo are attached in both cases.
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
  ParameterPtr new_parameter = add_parameter();
  MS_EXCEPTION_IF_NULL(new_parameter);
  // create kernel_info form new parameter
  auto kernel_info = std::make_shared<device::KernelInfo>();
  size_t output_tensor_num = 1;
  // if use default parameter = nullptr,it remarks create a new parameter from no parameter
  if (parameter == nullptr) {
    new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
    kernel_info->SetFeatureMapFlag(true);
  } else {
    // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
    new_parameter->set_abstract(parameter->abstract());
    new_parameter->set_name(parameter->name());
    if (AnfAlgo::IsParameterWeight(parameter)) {
      // weights keep their default value and are not feature maps
      new_parameter->set_default_param(parameter->default_param());
      kernel_info->SetFeatureMapFlag(false);
    } else {
      kernel_info->SetFeatureMapFlag(true);
    }
    // if output is a tuple tensor,now can use for loop to handle tuple tensor
    output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
  }
  new_parameter->set_kernel_info(kernel_info);
  // create kernel_build_info for new parameter
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  // create init data type,
  std::vector<TypeId> init_data_type = {};
  for (size_t i = 0; i < output_tensor_num; i++) {
    TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i);
    // weights start with an unknown device type; others take the inferred type
    init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
  }
  // set the format of parameter to DEFAULT_FORMAT
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
  // set parameter initaial device data type
  kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
  AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
  return new_parameter;
}
  233. std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
  234. MS_EXCEPTION_IF_NULL(value_node);
  235. auto node_value = value_node->value();
  236. auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
  237. std::vector<AnfNodePtr> convert_inputs;
  238. if (!node_value->isa<ValueTuple>()) {
  239. MS_LOG(EXCEPTION) << "multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
  240. }
  241. auto value_tuple = node_value->cast<ValueTuplePtr>();
  242. MS_EXCEPTION_IF_NULL(value_tuple);
  243. if (value_tuple->size() != output_size) {
  244. MS_LOG(EXCEPTION) << "value tuple size" << value_tuple->size()
  245. << " is not mathced with the value node's output size" << output_size;
  246. }
  247. for (size_t index = 0; index < value_tuple->value().size(); ++index) {
  248. auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
  249. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)},
  250. {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
  251. AddValueNodeToGraph(new_value_node);
  252. auto kernel_info = std::make_shared<device::KernelInfo>();
  253. new_value_node->set_kernel_info(kernel_info);
  254. kernel_info->SetFeatureMapFlag(false);
  255. // create kernel_build_info for new value node
  256. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  257. // set the format of value_node to DEFAULT_FORMAT
  258. kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  259. // set value node initial device data type = infer data type
  260. kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
  261. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  262. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  263. AddValueNodeToGraph(new_value_node);
  264. convert_inputs.emplace_back(new_value_node);
  265. }
  266. if (!RemoveValueNodeFromGraph(value_node)) {
  267. MS_LOG(WARNING) << "failed to remove the value_node " << value_node->DebugString();
  268. }
  269. return convert_inputs;
  270. }
  271. ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  272. MS_EXCEPTION_IF_NULL(value_node);
  273. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  274. new_value_node->set_abstract(value_node->abstract());
  275. // create kernel_info fo new value node
  276. auto kernel_info = std::make_shared<device::KernelInfo>();
  277. kernel_info->SetFeatureMapFlag(false);
  278. new_value_node->set_kernel_info(kernel_info);
  279. // create kernel_build_info for new value node
  280. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  281. // set the format of value_node to DEFAULT_FORMAT
  282. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  283. // set value node initial device data type = infer data type
  284. std::vector<TypeId> types;
  285. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
  286. types.push_back(kTypeUnknown);
  287. }
  288. kernel_build_info_builder->SetOutputsDeviceType(types);
  289. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  290. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  291. return new_value_node;
  292. }
// Return the graph's input node list; inputs_ must have been populated
// before this accessor is called.
const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
  MS_EXCEPTION_IF_NULL(inputs_);
  return *inputs_;
}
  297. void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
  298. MS_EXCEPTION_IF_NULL(front_anf);
  299. MS_EXCEPTION_IF_NULL(backend_anf);
  300. if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
  301. MS_LOG(EXCEPTION) << "anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
  302. }
  303. if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
  304. MS_LOG(EXCEPTION) << "kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
  305. }
  306. front_backend_anf_map_[front_anf] = backend_anf;
  307. backend_front_anf_map_[backend_anf] = front_anf;
  308. }
  309. void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
  310. MS_EXCEPTION_IF_NULL(old_backend_anf);
  311. MS_EXCEPTION_IF_NULL(new_backend_anf);
  312. if (old_backend_anf.get() == new_backend_anf.get()) {
  313. MS_LOG(EXCEPTION) << "old can't be same with new";
  314. }
  315. if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
  316. MS_LOG(EXCEPTION) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
  317. }
  318. if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
  319. MS_LOG(EXCEPTION) << "anf is not exist in the mape ,old " << old_backend_anf->DebugString();
  320. }
  321. front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
  322. backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
  323. // delete old kernel
  324. (void)backend_front_anf_map_.erase(old_backend_anf);
  325. }
  326. // get kernel by anf
  327. AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
  328. if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {
  329. return nullptr;
  330. }
  331. return front_backend_anf_map_[front_anf];
  332. }
  333. bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
  334. return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
  335. }
  336. ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
  337. if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) {
  338. return nullptr;
  339. }
  340. return tensor_to_value_node_map_[tensor];
  341. }
// Remember which graph value node materializes `tensor`, so later lookups
// (GetValueNodeByTensor) can reuse it instead of creating a duplicate.
void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(value_node);
  tensor_to_value_node_map_[tensor] = value_node;
}
  347. void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
  348. MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num;
  349. auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
  350. // add output depend edge of input
  351. auto output_it = node_output_edges_.find(input);
  352. if (output_it == node_output_edges_.end()) {
  353. node_output_edges_[input] = std::vector<std::pair<AnfNodePtr, size_t>>{output_depend_edge};
  354. } else {
  355. output_it->second.push_back(output_depend_edge);
  356. }
  357. // add input depend edge of output
  358. auto input_depend_edge = std::pair<AnfNodePtr, size_t>(input, depend_edge_num);
  359. auto input_it = node_input_edges_.find(node);
  360. if (input_it == node_input_edges_.end()) {
  361. node_input_edges_[node] = std::vector<std::pair<AnfNodePtr, size_t>>{input_depend_edge};
  362. } else {
  363. input_it->second.push_back(input_depend_edge);
  364. }
  365. // add node input depend num
  366. auto depend_it = node_input_num_.find(node);
  367. if (depend_it == node_input_num_.end()) {
  368. node_input_num_[node] = depend_edge_num;
  369. } else {
  370. depend_it->second += depend_edge_num;
  371. }
  372. }
  373. std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
  374. MS_EXCEPTION_IF_NULL(node);
  375. auto it = node_output_edges_.find(node);
  376. if (it == node_output_edges_.end()) {
  377. MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]";
  378. }
  379. std::vector<AnfNodePtr> output_nodes;
  380. auto trans = [](const std::pair<AnfNodePtr, size_t> &pair) -> AnfNodePtr { return pair.first; };
  381. (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans);
  382. return output_nodes;
  383. }
  384. // update the depend relations of control depend
  385. void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
  386. for (const auto &node : depends) {
  387. MS_EXCEPTION_IF_NULL(node);
  388. if (!node->isa<CNode>()) {
  389. return;
  390. }
  391. auto cnode = node->cast<CNodePtr>();
  392. MS_EXCEPTION_IF_NULL(cnode);
  393. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
  394. MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
  395. }
  396. auto prior_node = cnode->input(kControlDependPriorIndex);
  397. auto depend_node = cnode->input(kControlDependBehindIndex);
  398. MS_EXCEPTION_IF_NULL(prior_node);
  399. MS_EXCEPTION_IF_NULL(depend_node);
  400. std::vector<AnfNodePtr> prior_nodes = {prior_node};
  401. std::vector<AnfNodePtr> depend_nodes = {depend_node};
  402. MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString();
  403. if (prior_node->isa<Parameter>()) {
  404. prior_nodes = GetOutputNodes(prior_node);
  405. }
  406. if (depend_node->isa<Parameter>()) {
  407. depend_nodes = GetOutputNodes(depend_node);
  408. }
  409. for (auto &first_node : prior_nodes) {
  410. for (auto &second_node : depend_nodes) {
  411. MS_EXCEPTION_IF_NULL(first_node);
  412. MS_EXCEPTION_IF_NULL(second_node);
  413. MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
  414. AddDependEdge(second_node, first_node, 1);
  415. }
  416. }
  417. }
  418. }
  419. bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  420. std::unordered_set<AnfNodePtr> *visited_nodes) {
  421. MS_EXCEPTION_IF_NULL(node);
  422. if (!node->isa<CNode>()) {
  423. return false;
  424. }
  425. auto cnode = node->cast<CNodePtr>();
  426. MS_EXCEPTION_IF_NULL(cnode);
  427. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
  428. return false;
  429. }
  430. // set the control depend visited but don't push it into the que
  431. if (visited_nodes->find(node) != visited_nodes->end()) {
  432. MS_LOG(EXCEPTION) << "control depend[" << node->DebugString() << "] has been handled before";
  433. }
  434. (void)visited_nodes->insert(cnode);
  435. // add a 0 depend num to keep the link relations to prepare for finding zero output nodes
  436. auto prior_node = cnode->input(kControlDependPriorIndex);
  437. auto depend_node = cnode->input(kControlDependBehindIndex);
  438. for (const auto &input : cnode->inputs()) {
  439. AddDependEdge(node, input, 0);
  440. }
  441. PushNoVisitedNode(depend_node, que, visited_nodes);
  442. PushNoVisitedNode(prior_node, que, visited_nodes);
  443. return true;
  444. }
// Rebuild the dependency bookkeeping (node_input_edges_, node_output_edges_,
// node_input_num_) by a BFS from the graph's return node toward its inputs.
// Zero-input leaves (parameters and value nodes) are pushed onto *seed_nodes
// for the subsequent topological sort, and ControlDepend nodes are collected
// and converted into extra depend edges via UpdateControlDependRelations.
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
  node_output_edges_.clear();
  node_input_num_.clear();
  node_input_edges_.clear();
  std::vector<AnfNodePtr> control_depends;
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> que;
  que.push(get_return());
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<Parameter>() || node->isa<ValueNode>()) {
      // leaves have no inputs: they seed the topological sort
      seed_nodes->push(node);
      continue;
    }
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // handle data links
    for (const auto &input : cnode->inputs()) {
      size_t depend_edge_num = 1;
      // handle control depend,all inputs of control depend has no depend edge
      if (HandleControlDependNode(input, &que, &visited_nodes)) {
        control_depends.push_back(input);
        depend_edge_num = 0;
      }
      PushNoVisitedNode(input, &que, &visited_nodes);
      AddDependEdge(node, input, depend_edge_num);
    }
  }
  UpdateControlDependRelations(control_depends);
}
  480. void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); }
  481. bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }
  482. AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
  483. if (!IsInRefOutputMap(out_pair)) {
  484. MS_LOG(EXCEPTION) << "out_pair is not in RefOutputMap";
  485. }
  486. return ref_out_in_map_.at(out_pair);
  487. }
  488. void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
  489. if (IsInRefOutputMap(final_pair)) {
  490. MS_LOG(EXCEPTION) << "out_pair is already in RefOutputMap";
  491. }
  492. (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
  493. }
  494. bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
  495. if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
  496. (void)graph_value_nodes_.erase(value_node);
  497. return true;
  498. }
  499. return false;
  500. }
  501. } // namespace session
  502. } // namespace mindspore