You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel_graph.cc 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/kernel_graph.h"
  17. #include <algorithm>
  18. #include <queue>
  19. #include <unordered_set>
  20. #include <set>
  21. #include "operator/ops.h"
  22. #include "ir/param_value_py.h"
  23. #include "session/anf_runtime_algorithm.h"
  24. #include "device/kernel_info.h"
  25. #include "kernel/kernel_build_info.h"
  26. #include "device/kernel_runtime_manager.h"
namespace mindspore {
namespace session {
namespace {
// Attribute keys written onto real CNodes by NewCNode to record feature-map
// information for later kernel selection.
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
  32. void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  33. std::unordered_set<AnfNodePtr> *visited_nodes) {
  34. MS_EXCEPTION_IF_NULL(que);
  35. MS_EXCEPTION_IF_NULL(visited_nodes);
  36. if (visited_nodes->find(node) == visited_nodes->end()) {
  37. que->push(node);
  38. (void)visited_nodes->insert(node);
  39. MS_LOG(DEBUG) << "Push que:" << node->DebugString();
  40. }
  41. }
  42. std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
  43. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(call_node, 0);
  44. MS_EXCEPTION_IF_NULL(item_with_index.first);
  45. if (!AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
  46. return {item_with_index.first};
  47. }
  48. std::vector<AnfNodePtr> real_inputs;
  49. auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>());
  50. for (const auto &child_graph : child_graphs) {
  51. if (AnfAlgo::IsWhileTrueGraph(child_graph)) {
  52. continue;
  53. }
  54. auto real_input = child_graph->output();
  55. auto child_real_inputs = GetCallRealOutputs(real_input);
  56. std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs));
  57. }
  58. return real_inputs;
  59. }
  60. } // namespace
  61. std::vector<AnfNodePtr> KernelGraph::outputs() const {
  62. auto graph_output = output();
  63. if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
  64. auto make_tuple = output()->cast<CNodePtr>();
  65. MS_EXCEPTION_IF_NULL(make_tuple);
  66. auto &inputs = make_tuple->inputs();
  67. return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end());
  68. }
  69. return std::vector<AnfNodePtr>(1, graph_output);
  70. }
  71. void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
  72. std::unordered_set<AnfNodePtr> *visited_nodes) {
  73. MS_EXCEPTION_IF_NULL(visit_queue);
  74. MS_EXCEPTION_IF_NULL(visited_nodes);
  75. auto it = node_output_edges_.find(node);
  76. if (it == node_output_edges_.end()) {
  77. // value node and parameter has no input,no need to print log
  78. if (node->isa<CNode>()) {
  79. MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
  80. }
  81. return;
  82. }
  83. // visit all reduce node first, then other nodes
  84. std::vector<AnfNodePtr> active_nodes;
  85. for (const auto &output_edge : it->second) {
  86. auto next_node = output_edge.first;
  87. if (node_input_num_.find(next_node) == node_input_num_.end()) {
  88. MS_EXCEPTION_IF_NULL(next_node);
  89. MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
  90. }
  91. MS_EXCEPTION_IF_NULL(next_node);
  92. MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
  93. << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
  94. if (node_input_num_[next_node] < output_edge.second) {
  95. MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
  96. << ",depend edge:" << output_edge.second;
  97. }
  98. node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
  99. // allreduce first
  100. if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
  101. (void)visited_nodes->insert(next_node);
  102. if (AnfAlgo::IsCommunicationOp(next_node)) {
  103. MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
  104. visit_queue->push(next_node);
  105. } else {
  106. active_nodes.emplace_back(next_node);
  107. }
  108. }
  109. }
  110. for (auto &node : active_nodes) {
  111. MS_LOG(DEBUG) << "visit node:" << node->DebugString();
  112. visit_queue->push(node);
  113. }
  114. }
// Compute execution_order_ for the graph.
// Kahn-style topological sort driven by the pending-input counters built by
// UpdateNodeEdgeList, with one twist: after a communication op (e.g.
// AllReduce) is emitted, its ready descendants are drained through a
// dedicated queue ahead of ordinary nodes, so communication ops are chained
// as early as possible. CheckLoop() verifies every counter reached zero.
void KernelGraph::SetExecOrderByDefault() {
  std::queue<AnfNodePtr> seed_nodes;
  UpdateNodeEdgeList(&seed_nodes);  // rebuilds edges; seeds = parameters/value nodes
  execution_order_.clear();
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> zero_input_nodes;
  AnfNodePtr last_communication_node = nullptr;
  std::queue<AnfNodePtr> communication_descendants;
  while (!seed_nodes.empty() || last_communication_node != nullptr) {
    // seed nodes first, then visit last all reduce node descendant
    if (seed_nodes.empty()) {
      VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
      last_communication_node = nullptr;
    } else {
      zero_input_nodes.push(seed_nodes.front());
      seed_nodes.pop();
    }
    // all reduce node descendant first, then common queue
    while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
      AnfNodePtr node = nullptr;
      bool is_communication_descendant = false;
      // drain communication descendants before ordinary zero-input nodes
      if (communication_descendants.empty()) {
        node = zero_input_nodes.front();
        zero_input_nodes.pop();
      } else {
        node = communication_descendants.front();
        communication_descendants.pop();
        is_communication_descendant = true;
      }
      // add execute node (only real kernels are recorded in the order)
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
        execution_order_.push_back(node->cast<CNodePtr>());
      }
      // for all reduce node, visit last all reduce node descendant
      if (AnfAlgo::IsCommunicationOp(node)) {
        if (last_communication_node != nullptr) {
          VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
        }
        last_communication_node = node;
      } else if (is_communication_descendant) {
        VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
      } else {
        VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
      }
    }
  }
  CheckLoop();
}
  164. void KernelGraph::CheckLoop() {
  165. std::map<AnfNodePtr, size_t> none_zero_nodes;
  166. if (node_input_edges_.size() != node_input_num_.size()) {
  167. MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
  168. << "not equal to node_input_num_ size:" << node_input_num_.size();
  169. }
  170. for (auto &it : node_input_num_) {
  171. MS_EXCEPTION_IF_NULL(it.first);
  172. string str;
  173. auto node_input_it = node_input_edges_.find(it.first);
  174. if (node_input_it == node_input_edges_.end()) {
  175. MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
  176. }
  177. for (const auto &input_edge : node_input_edges_[it.first]) {
  178. MS_EXCEPTION_IF_NULL(input_edge.first);
  179. str = str.append(input_edge.first->DebugString()).append("|");
  180. }
  181. if (it.second != 0) {
  182. MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
  183. none_zero_nodes[it.first] = it.second;
  184. }
  185. }
  186. // if don't consider control depend and loop exit,a exception will be throw
  187. if (!none_zero_nodes.empty()) {
  188. MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
  189. }
  190. }
  191. CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  192. auto cnode = FuncGraph::NewCNode(inputs);
  193. MS_EXCEPTION_IF_NULL(cnode);
  194. cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
  195. // create kernel_info from new parameter
  196. auto kernel_info = std::make_shared<device::KernelInfo>();
  197. std::vector<size_t> feature_map_input_indexs;
  198. // if the node only has the primitive(such as getNext) or the node's input has a feature map input
  199. // then the node's output is a feature map output
  200. for (size_t index = 1; index < inputs.size(); ++index) {
  201. auto node = inputs[index];
  202. if (AnfAlgo::IsFeatureMapOutput(node)) {
  203. feature_map_input_indexs.push_back(index);
  204. }
  205. }
  206. if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
  207. AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
  208. }
  209. if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
  210. kernel_info->SetFeatureMapFlag(true);
  211. }
  212. if (AnfAlgo::IsRealCNodeKernel(cnode)) {
  213. AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
  214. AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
  215. }
  216. cnode->set_kernel_info(kernel_info);
  217. AnfAlgo::SetGraphId(graph_id_, cnode.get());
  218. return cnode;
  219. }
// Clone `cnode` into a new CNode. If the original backend node is tracked in
// the front/backend map, the map entry is redirected to the clone.
CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto new_cnode = std::make_shared<CNode>(*cnode);
  // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map
  if (BackendNodeExistInFrontBackendMap(cnode)) {
    FrontBackendlMapUpdate(cnode, new_cnode);
  }
  // NOTE(review): the graph id is set on the source `cnode`, not on `new_cnode`.
  // If the copy constructor does not share kernel info, the clone stays
  // untagged — confirm whether this should be new_cnode.get().
  AnfAlgo::SetGraphId(graph_id_, cnode.get());
  return new_cnode;
}
  230. ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
  231. ParameterPtr new_parameter = add_parameter();
  232. MS_EXCEPTION_IF_NULL(new_parameter);
  233. // create kernel_info form new parameter
  234. auto kernel_info = std::make_shared<device::KernelInfo>();
  235. size_t output_tensor_num = 1;
  236. // if use default parameter = nullptr,it remarks create a new parameter from no parameter
  237. if (parameter == nullptr) {
  238. new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
  239. kernel_info->SetFeatureMapFlag(true);
  240. } else {
  241. // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
  242. new_parameter->set_abstract(parameter->abstract());
  243. new_parameter->set_name(parameter->name());
  244. if (AnfAlgo::IsParameterWeight(parameter)) {
  245. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
  246. auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
  247. new_parameter->set_default_param(param_value_new);
  248. kernel_info->SetFeatureMapFlag(false);
  249. } else {
  250. kernel_info->SetFeatureMapFlag(true);
  251. }
  252. // if output is a tuple tensor,now can use for loop to handle tuple tensor
  253. output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
  254. }
  255. new_parameter->set_kernel_info(kernel_info);
  256. // create kernel_build_info for new parameter
  257. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  258. // create init data type,
  259. std::vector<TypeId> init_data_type = {};
  260. for (size_t i = 0; i < output_tensor_num; i++) {
  261. TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i);
  262. init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
  263. }
  264. // set the format of parameter to DEFAULT_FORMAT
  265. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
  266. // set parameter initaial device data type
  267. kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
  268. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
  269. AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
  270. return new_parameter;
  271. }
  272. std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
  273. MS_EXCEPTION_IF_NULL(value_node);
  274. auto node_value = value_node->value();
  275. auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
  276. std::vector<AnfNodePtr> convert_inputs;
  277. if (!node_value->isa<ValueTuple>()) {
  278. MS_LOG(EXCEPTION) << "multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
  279. }
  280. auto value_tuple = node_value->cast<ValueTuplePtr>();
  281. MS_EXCEPTION_IF_NULL(value_tuple);
  282. if (value_tuple->size() != output_size) {
  283. MS_LOG(EXCEPTION) << "value tuple size" << value_tuple->size()
  284. << " is not mathced with the value node's output size" << output_size;
  285. }
  286. for (size_t index = 0; index < value_tuple->value().size(); ++index) {
  287. auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
  288. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)},
  289. {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
  290. AddValueNodeToGraph(new_value_node);
  291. auto kernel_info = std::make_shared<device::KernelInfo>();
  292. new_value_node->set_kernel_info(kernel_info);
  293. kernel_info->SetFeatureMapFlag(false);
  294. // create kernel_build_info for new value node
  295. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  296. // set the format of value_node to DEFAULT_FORMAT
  297. kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  298. // set value node initial device data type = infer data type
  299. kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
  300. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  301. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  302. AddValueNodeToGraph(new_value_node);
  303. convert_inputs.emplace_back(new_value_node);
  304. }
  305. if (!RemoveValueNodeFromGraph(value_node)) {
  306. MS_LOG(WARNING) << "failed to remove the value_node " << value_node->DebugString();
  307. }
  308. return convert_inputs;
  309. }
  310. ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  311. MS_EXCEPTION_IF_NULL(value_node);
  312. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  313. new_value_node->set_abstract(value_node->abstract());
  314. // create kernel_info fo new value node
  315. auto kernel_info = std::make_shared<device::KernelInfo>();
  316. kernel_info->SetFeatureMapFlag(false);
  317. new_value_node->set_kernel_info(kernel_info);
  318. // create kernel_build_info for new value node
  319. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  320. // set the format of value_node to DEFAULT_FORMAT
  321. auto output_tensor_num = AnfAlgo::GetOutputTensorNum(value_node);
  322. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
  323. // set value node initial device data type = infer data type
  324. std::vector<TypeId> types = std::vector<TypeId>(output_tensor_num, kTypeUnknown);
  325. kernel_build_info_builder->SetOutputsDeviceType(types);
  326. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  327. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  328. return new_value_node;
  329. }
// Graph input parameters. inputs_ is owned elsewhere and must be non-null.
const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
  MS_EXCEPTION_IF_NULL(inputs_);
  return *inputs_;
}
  334. void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
  335. MS_EXCEPTION_IF_NULL(front_anf);
  336. MS_EXCEPTION_IF_NULL(backend_anf);
  337. if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
  338. MS_LOG(EXCEPTION) << "anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
  339. }
  340. if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
  341. MS_LOG(EXCEPTION) << "kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
  342. }
  343. front_backend_anf_map_[front_anf] = backend_anf;
  344. backend_front_anf_map_[backend_anf] = front_anf;
  345. }
  346. void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
  347. MS_EXCEPTION_IF_NULL(old_backend_anf);
  348. MS_EXCEPTION_IF_NULL(new_backend_anf);
  349. if (old_backend_anf.get() == new_backend_anf.get()) {
  350. MS_LOG(EXCEPTION) << "old can't be same with new";
  351. }
  352. if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
  353. MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
  354. return;
  355. }
  356. if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
  357. MS_LOG(EXCEPTION) << "anf is not exist in the map ,old " << old_backend_anf->DebugString();
  358. }
  359. front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
  360. backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
  361. // delete old kernel
  362. (void)backend_front_anf_map_.erase(old_backend_anf);
  363. }
  364. // get kernel by anf
  365. AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
  366. if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {
  367. return nullptr;
  368. }
  369. return front_backend_anf_map_[front_anf];
  370. }
  371. bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
  372. return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
  373. }
  374. ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
  375. if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) {
  376. return nullptr;
  377. }
  378. return tensor_to_value_node_map_[tensor];
  379. }
// Cache the value node created for `tensor` so later lookups (via
// GetValueNodeByTensor) can reuse it; overwrites any previous entry.
void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(value_node);
  tensor_to_value_node_map_[tensor] = value_node;
}
  385. void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
  386. MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num;
  387. auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
  388. // add output depend edge of input
  389. auto output_it = node_output_edges_.find(input);
  390. if (output_it == node_output_edges_.end()) {
  391. node_output_edges_[input] = std::vector<std::pair<AnfNodePtr, size_t>>{output_depend_edge};
  392. } else {
  393. output_it->second.push_back(output_depend_edge);
  394. }
  395. // add input depend edge of output
  396. auto input_depend_edge = std::pair<AnfNodePtr, size_t>(input, depend_edge_num);
  397. auto input_it = node_input_edges_.find(node);
  398. if (input_it == node_input_edges_.end()) {
  399. node_input_edges_[node] = std::vector<std::pair<AnfNodePtr, size_t>>{input_depend_edge};
  400. } else {
  401. input_it->second.push_back(input_depend_edge);
  402. }
  403. // add node input depend num
  404. auto depend_it = node_input_num_.find(node);
  405. if (depend_it == node_input_num_.end()) {
  406. node_input_num_[node] = depend_edge_num;
  407. } else {
  408. depend_it->second += depend_edge_num;
  409. }
  410. }
  411. std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
  412. MS_EXCEPTION_IF_NULL(node);
  413. auto it = node_output_edges_.find(node);
  414. if (it == node_output_edges_.end()) {
  415. MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]";
  416. }
  417. std::vector<AnfNodePtr> output_nodes;
  418. auto trans = [](const std::pair<AnfNodePtr, size_t> &pair) -> AnfNodePtr { return pair.first; };
  419. (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans);
  420. return output_nodes;
  421. }
  422. // update the depend relations of control depend
  423. void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
  424. for (const auto &node : depends) {
  425. MS_EXCEPTION_IF_NULL(node);
  426. if (!node->isa<CNode>()) {
  427. return;
  428. }
  429. auto cnode = node->cast<CNodePtr>();
  430. MS_EXCEPTION_IF_NULL(cnode);
  431. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
  432. MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
  433. }
  434. auto prior_node = cnode->input(kControlDependPriorIndex);
  435. auto depend_node = cnode->input(kControlDependBehindIndex);
  436. MS_EXCEPTION_IF_NULL(prior_node);
  437. MS_EXCEPTION_IF_NULL(depend_node);
  438. std::vector<AnfNodePtr> prior_nodes = {prior_node};
  439. std::vector<AnfNodePtr> depend_nodes = {depend_node};
  440. MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString();
  441. if (prior_node->isa<Parameter>()) {
  442. prior_nodes = GetOutputNodes(prior_node);
  443. }
  444. if (depend_node->isa<Parameter>()) {
  445. depend_nodes = GetOutputNodes(depend_node);
  446. }
  447. for (auto &first_node : prior_nodes) {
  448. for (auto &second_node : depend_nodes) {
  449. MS_EXCEPTION_IF_NULL(first_node);
  450. MS_EXCEPTION_IF_NULL(second_node);
  451. MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
  452. AddDependEdge(second_node, first_node, 1);
  453. }
  454. }
  455. }
  456. }
// Special-case a ControlDepend node during edge-list construction.
// Returns false when `node` is not a ControlDepend CNode (the caller then
// treats it as an ordinary input). Otherwise: marks it visited WITHOUT
// queueing it, links all of its inputs with 0-weight edges (keeps the graph
// connected without adding real dependencies), queues the behind/prior
// operands for traversal, and returns true.
bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
                                          std::unordered_set<AnfNodePtr> *visited_nodes) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
    return false;
  }
  // set the control depend visited but don't push it into the que
  if (visited_nodes->find(node) != visited_nodes->end()) {
    // already processed once; still report it as a ControlDepend to the caller
    return true;
  }
  (void)visited_nodes->insert(cnode);
  // add a 0 depend num to keep the link relations to prepare for finding zero output nodes
  auto prior_node = cnode->input(kControlDependPriorIndex);
  auto depend_node = cnode->input(kControlDependBehindIndex);
  for (const auto &input : cnode->inputs()) {
    AddDependEdge(node, input, 0);
  }
  PushNoVisitedNode(depend_node, que, visited_nodes);
  PushNoVisitedNode(prior_node, que, visited_nodes);
  return true;
}
// Rebuild node_output_edges_/node_input_edges_/node_input_num_ via a BFS from
// the graph's Return node. Parameters and value nodes (which have no inputs)
// are collected into *seed_nodes to seed the topological sort.
// ControlDepend nodes get 0-weight edges here; their real ordering
// constraints are added afterwards by UpdateControlDependRelations.
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
  node_output_edges_.clear();
  node_input_num_.clear();
  node_input_edges_.clear();
  std::vector<AnfNodePtr> control_depends;
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> que;
  que.push(get_return());
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<Parameter>() || node->isa<ValueNode>()) {
      seed_nodes->push(node);
      continue;
    }
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // handle data links
    for (const auto &input : cnode->inputs()) {
      size_t depend_edge_num = 1;
      // handle control depend,all inputs of control depend has no depend edge
      if (HandleControlDependNode(input, &que, &visited_nodes)) {
        control_depends.push_back(input);
        depend_edge_num = 0;
      }
      PushNoVisitedNode(input, &que, &visited_nodes);
      AddDependEdge(node, input, depend_edge_num);
    }
  }
  UpdateControlDependRelations(control_depends);
}
// Register a value node as belonging to this graph (set insert; duplicates are no-ops).
void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); }
// Whether `pair` (node, output index) has a recorded ref-output origin.
bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }
  520. AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
  521. if (!IsInRefOutputMap(out_pair)) {
  522. MS_LOG(EXCEPTION) << "out_pair is not in RefOutputMap";
  523. }
  524. return ref_out_in_map_.at(out_pair);
  525. }
  526. void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
  527. if (IsInRefOutputMap(final_pair)) {
  528. MS_LOG(EXCEPTION) << "out_pair is already in RefOutputMap";
  529. }
  530. (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
  531. }
  532. bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
  533. if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
  534. (void)graph_value_nodes_.erase(value_node);
  535. return true;
  536. }
  537. return false;
  538. }
  539. void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node) {
  540. MS_EXCEPTION_IF_NULL(old_anf_node);
  541. MS_EXCEPTION_IF_NULL(new_anf_node);
  542. MS_EXCEPTION_IF_NULL(inputs_);
  543. auto it = node_output_edges_.find(old_anf_node);
  544. if (it == node_output_edges_.end()) {
  545. MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map";
  546. }
  547. auto &outputs = it->second;
  548. for (auto &output_node : outputs) {
  549. auto output_cnode = output_node.first->cast<CNodePtr>();
  550. MS_EXCEPTION_IF_NULL(output_cnode);
  551. auto &output_node_inputs = output_cnode->inputs();
  552. for (size_t i = 1; i < output_node_inputs.size(); i++) {
  553. if (output_node_inputs[i] == old_anf_node) {
  554. output_cnode->set_input(i, new_anf_node);
  555. }
  556. }
  557. // update graph inputs
  558. for (size_t i = 0; i < inputs_->size(); i++) {
  559. if ((*inputs_)[i] == old_anf_node) {
  560. (*inputs_)[i] = new_anf_node;
  561. break;
  562. }
  563. }
  564. }
  565. // update front to backend map
  566. FrontBackendlMapUpdate(old_anf_node, new_anf_node);
  567. // update output depend relations
  568. node_output_edges_[new_anf_node] = it->second;
  569. (void)node_output_edges_.erase(old_anf_node);
  570. }
  571. void KernelGraph::UpdateExecuteKernelStreamLabel() {
  572. for (auto &kernel : execution_order_) {
  573. AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
  574. }
  575. }
  576. void KernelGraph::UpdateChildGraphOrder() {
  577. MS_LOG(INFO) << "graph id:" << graph_id_;
  578. auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
  579. for (auto &old_child_graph : child_graph_order_) {
  580. old_child_graph->set_parent_graph(nullptr);
  581. }
  582. child_graph_order_.clear();
  583. for (auto &call_node : call_nodes) {
  584. MS_EXCEPTION_IF_NULL(call_node);
  585. auto call_child_graphs = AnfAlgo ::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
  586. for (const auto &child_graph : call_child_graphs) {
  587. MS_EXCEPTION_IF_NULL(child_graph);
  588. if (child_graph != parent_graph()) {
  589. child_graph->set_parent_graph(shared_from_this()->cast<std::shared_ptr<KernelGraph>>());
  590. child_graph_order_.push_back(child_graph);
  591. }
  592. }
  593. }
  594. for (size_t i = 0; i < child_graph_order_.size(); i++) {
  595. MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order_[i]->graph_id() << "]";
  596. }
  597. }
  598. std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
  599. std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
  600. if (IsLeafGraph()) {
  601. leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
  602. } else {
  603. for (const auto &child_graph : child_graph_order_) {
  604. MS_EXCEPTION_IF_NULL(child_graph);
  605. auto child_leaf_graph_order = child_graph->GetLeafGraphOrder();
  606. std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
  607. }
  608. }
  609. return leaf_graph_order;
  610. }
// A leaf graph has no child graphs (see UpdateChildGraphOrder).
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
  612. std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
  613. auto anf_list = TopoSort(get_return());
  614. std::vector<CNodePtr> result;
  615. for (const auto &anf : anf_list) {
  616. if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
  617. result.push_back(anf->cast<CNodePtr>());
  618. }
  619. }
  620. return result;
  621. }
  622. std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
  623. MS_EXCEPTION_IF_NULL(parameter);
  624. if (real_inputs_.find(parameter) == real_inputs_.end()) {
  625. return {};
  626. }
  627. return real_inputs_[parameter];
  628. }
  629. void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
  630. MS_EXCEPTION_IF_NULL(parameter);
  631. MS_EXCEPTION_IF_NULL(arg);
  632. MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
  633. MS_EXCEPTION_IF_NULL(parameter);
  634. MS_EXCEPTION_IF_NULL(arg);
  635. if (real_inputs_.find(parameter) == real_inputs_.end()) {
  636. real_inputs_[parameter] = std::set<AnfNodePtr>();
  637. }
  638. auto &args = real_inputs_[parameter];
  639. (void)args.insert(arg);
  640. }
  641. void KernelGraph::UpdateCallRealInput() {
  642. MS_LOG(INFO) << "Update graph id: " << graph_id_;
  643. for (auto &it : real_inputs_) {
  644. auto &parameter = it.first;
  645. MS_EXCEPTION_IF_NULL(parameter);
  646. auto &real_inputs = it.second;
  647. std::set<AnfNodePtr> new_real_inputs;
  648. std::set<AnfNodePtr> erase_real_inputs;
  649. for (auto &real_input : real_inputs) {
  650. // if real input is a call node ,find the child graph output act as the new real input
  651. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0);
  652. MS_EXCEPTION_IF_NULL(item_with_index.first);
  653. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
  654. MS_LOG(INFO) << "paramter: " << parameter->DebugString()
  655. << " erase real input:" << item_with_index.first->DebugString();
  656. (void)erase_real_inputs.insert(item_with_index.first);
  657. auto call_node_outputs = GetCallRealOutputs(item_with_index.first);
  658. for (auto &call_node_output : call_node_outputs) {
  659. MS_EXCEPTION_IF_NULL(call_node_output);
  660. MS_LOG(INFO) << "paramter: " << parameter->DebugString()
  661. << " insert real input:" << call_node_output->DebugString();
  662. (void)new_real_inputs.insert(call_node_output);
  663. }
  664. continue;
  665. }
  666. for (auto &erase_node : erase_real_inputs) {
  667. (void)real_inputs.erase(erase_node);
  668. }
  669. for (auto &new_real_input : new_real_inputs) {
  670. (void)real_inputs.insert(new_real_input);
  671. }
  672. }
  673. }
  674. }
// Human-readable graph name, e.g. "kernel_graph_3".
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
// Release runtime resources registered under this graph id.
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
  677. } // namespace session
  678. } // namespace mindspore