You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_graph.cc 32 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/kernel_graph.h"
  17. #include <algorithm>
  18. #include <queue>
  19. #include <unordered_set>
  20. #include <set>
  21. #include "operator/ops.h"
  22. #include "ir/param_value_py.h"
  23. #include "session/anf_runtime_algorithm.h"
  24. #include "device/kernel_info.h"
  25. #include "kernel/kernel_build_info.h"
  26. #include "device/kernel_runtime_manager.h"
  27. namespace mindspore {
  28. namespace session {
  29. namespace {
// Attribute keys stamped onto real cnodes in NewCNode() below to record
// feature-map information (which inputs are feature maps, and whether the
// output is one); consumed by later kernel-selection passes.
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
  32. void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  33. std::unordered_set<AnfNodePtr> *visited_nodes) {
  34. MS_EXCEPTION_IF_NULL(que);
  35. MS_EXCEPTION_IF_NULL(visited_nodes);
  36. if (visited_nodes->find(node) == visited_nodes->end()) {
  37. que->push(node);
  38. (void)visited_nodes->insert(node);
  39. MS_LOG(DEBUG) << "Push que:" << node->DebugString();
  40. }
  41. }
  42. std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
  43. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(call_node, 0);
  44. MS_EXCEPTION_IF_NULL(item_with_index.first);
  45. if (!AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
  46. return {item_with_index.first};
  47. }
  48. std::vector<AnfNodePtr> real_inputs;
  49. auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>());
  50. for (const auto &child_graph : child_graphs) {
  51. if (child_graph->get_output_null()) {
  52. continue;
  53. }
  54. auto real_input = child_graph->output();
  55. auto child_real_inputs = GetCallRealOutputs(real_input);
  56. std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs));
  57. }
  58. return real_inputs;
  59. }
  60. } // namespace
  61. std::vector<AnfNodePtr> KernelGraph::outputs() const {
  62. auto graph_output = output();
  63. if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
  64. auto make_tuple = output()->cast<CNodePtr>();
  65. MS_EXCEPTION_IF_NULL(make_tuple);
  66. auto &inputs = make_tuple->inputs();
  67. return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end());
  68. }
  69. return std::vector<AnfNodePtr>(1, graph_output);
  70. }
  71. void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
  72. std::unordered_set<AnfNodePtr> *visited_nodes) {
  73. MS_EXCEPTION_IF_NULL(visit_queue);
  74. MS_EXCEPTION_IF_NULL(visited_nodes);
  75. auto it = node_output_edges_.find(node);
  76. if (it == node_output_edges_.end()) {
  77. // value node and parameter has no input,no need to print log
  78. if (node->isa<CNode>()) {
  79. MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
  80. }
  81. return;
  82. }
  83. // visit all reduce node first, then other nodes
  84. std::vector<AnfNodePtr> active_nodes;
  85. for (const auto &output_edge : it->second) {
  86. auto next_node = output_edge.first;
  87. if (node_input_num_.find(next_node) == node_input_num_.end()) {
  88. MS_EXCEPTION_IF_NULL(next_node);
  89. MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
  90. }
  91. MS_EXCEPTION_IF_NULL(next_node);
  92. MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
  93. << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
  94. if (node_input_num_[next_node] < output_edge.second) {
  95. MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
  96. << ",depend edge:" << output_edge.second;
  97. }
  98. node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
  99. // allreduce first
  100. if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
  101. (void)visited_nodes->insert(next_node);
  102. if (AnfAlgo::IsCommunicationOp(next_node)) {
  103. MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
  104. visit_queue->push(next_node);
  105. } else {
  106. active_nodes.emplace_back(next_node);
  107. }
  108. }
  109. }
  110. for (auto &node : active_nodes) {
  111. MS_LOG(DEBUG) << "visit node:" << node->DebugString();
  112. visit_queue->push(node);
  113. }
  114. }
// Compute execution_order_ via a BFS topological sort over the edge lists
// rebuilt by UpdateNodeEdgeList(). Communication (all-reduce) ops are given
// priority: once one is scheduled, the descendants of the previously
// scheduled communication op are drained first, which serializes the
// communication kernels in a stable order. Finally the start label is moved
// to the front of the order and the end goto to the back.
void KernelGraph::SetExecOrderByDefault() {
  std::queue<AnfNodePtr> seed_nodes;
  UpdateNodeEdgeList(&seed_nodes);  // seeds = parameters and value nodes (in-degree 0)
  execution_order_.clear();
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> zero_input_nodes;
  AnfNodePtr last_communication_node = nullptr;
  std::queue<AnfNodePtr> communication_descendants;
  while (!seed_nodes.empty() || last_communication_node != nullptr) {
    // seed nodes first, then visit last all reduce node descendant
    if (seed_nodes.empty()) {
      VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
      last_communication_node = nullptr;
    } else {
      zero_input_nodes.push(seed_nodes.front());
      seed_nodes.pop();
    }
    // all reduce node descendant first, then common queue
    while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
      AnfNodePtr node = nullptr;
      bool is_communication_descendant = false;
      if (communication_descendants.empty()) {
        node = zero_input_nodes.front();
        zero_input_nodes.pop();
      } else {
        // drain the communication-descendant queue before the common queue
        node = communication_descendants.front();
        communication_descendants.pop();
        is_communication_descendant = true;
      }
      // add execute node
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
        execution_order_.push_back(node->cast<CNodePtr>());
      }
      // for all reduce node, visit last all reduce node descendant
      if (AnfAlgo::IsCommunicationOp(node)) {
        if (last_communication_node != nullptr) {
          VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
        }
        last_communication_node = node;
      } else if (is_communication_descendant) {
        VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
      } else {
        VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
      }
    }
  }
  CheckLoop();  // a non-zero pending-input counter left over means a cycle
  // resort start label / end goto
  std::vector<CNodePtr> re_order;
  if (start_label_ != nullptr) {
    re_order.push_back(start_label_);
  }
  for (auto &node : execution_order_) {
    if (node == start_label_ || node == end_goto_) {
      continue;
    }
    re_order.push_back(node);
  }
  if (end_goto_ != nullptr) {
    re_order.push_back(end_goto_);
  }
  execution_order_ = re_order;
}
  179. void KernelGraph::CheckLoop() {
  180. std::map<AnfNodePtr, size_t> none_zero_nodes;
  181. if (node_input_edges_.size() != node_input_num_.size()) {
  182. MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
  183. << "not equal to node_input_num_ size:" << node_input_num_.size();
  184. }
  185. for (auto &it : node_input_num_) {
  186. MS_EXCEPTION_IF_NULL(it.first);
  187. string str;
  188. auto node_input_it = node_input_edges_.find(it.first);
  189. if (node_input_it == node_input_edges_.end()) {
  190. MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
  191. }
  192. for (const auto &input_edge : node_input_edges_[it.first]) {
  193. MS_EXCEPTION_IF_NULL(input_edge.first);
  194. str = str.append(input_edge.first->DebugString()).append("|");
  195. }
  196. if (it.second != 0) {
  197. MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
  198. none_zero_nodes[it.first] = it.second;
  199. }
  200. }
  201. // if don't consider control depend and loop exit,a exception will be throw
  202. if (!none_zero_nodes.empty()) {
  203. MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
  204. }
  205. }
  206. CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  207. auto cnode = FuncGraph::NewCNode(inputs);
  208. MS_EXCEPTION_IF_NULL(cnode);
  209. cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
  210. // create kernel_info from new parameter
  211. auto kernel_info = std::make_shared<device::KernelInfo>();
  212. std::vector<size_t> feature_map_input_indexs;
  213. // if the node only has the primitive(such as getNext) or the node's input has a feature map input
  214. // then the node's output is a feature map output
  215. for (size_t index = 1; index < inputs.size(); ++index) {
  216. auto node = inputs[index];
  217. if (AnfAlgo::IsFeatureMapOutput(node)) {
  218. feature_map_input_indexs.push_back(index);
  219. }
  220. }
  221. if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
  222. AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
  223. }
  224. if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
  225. kernel_info->SetFeatureMapFlag(true);
  226. }
  227. if (AnfAlgo::IsRealCNodeKernel(cnode)) {
  228. AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
  229. AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
  230. }
  231. cnode->set_kernel_info(kernel_info);
  232. AnfAlgo::SetGraphId(graph_id_, cnode.get());
  233. return cnode;
  234. }
// Clone an existing backend CNode. If the original is tracked in the
// front/backend map, the mapping is redirected to the clone.
CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto new_cnode = std::make_shared<CNode>(*cnode);
  // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map
  if (BackendNodeExistInFrontBackendMap(cnode)) {
    FrontBackendlMapUpdate(cnode, new_cnode);
  }
  // NOTE(review): the graph id is set on the ORIGINAL cnode here, not on the
  // returned clone. Whether the clone picks it up through the copied CNode
  // state is not visible from this file — confirm this isn't meant to be
  // new_cnode.get().
  AnfAlgo::SetGraphId(graph_id_, cnode.get());
  return new_cnode;
}
  245. ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
  246. ParameterPtr new_parameter = add_parameter();
  247. MS_EXCEPTION_IF_NULL(new_parameter);
  248. // create kernel_info form new parameter
  249. auto kernel_info = std::make_shared<device::KernelInfo>();
  250. size_t output_tensor_num = 1;
  251. // if use default parameter = nullptr,it remarks create a new parameter from no parameter
  252. if (parameter == nullptr) {
  253. new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
  254. kernel_info->SetFeatureMapFlag(true);
  255. } else {
  256. // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
  257. new_parameter->set_abstract(parameter->abstract());
  258. new_parameter->set_name(parameter->name());
  259. if (AnfAlgo::IsParameterWeight(parameter)) {
  260. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
  261. auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
  262. new_parameter->set_default_param(param_value_new);
  263. kernel_info->SetFeatureMapFlag(false);
  264. } else {
  265. kernel_info->SetFeatureMapFlag(true);
  266. }
  267. // if output is a tuple tensor,now can use for loop to handle tuple tensor
  268. output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
  269. }
  270. new_parameter->set_kernel_info(kernel_info);
  271. // create kernel_build_info for new parameter
  272. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  273. // create init data type,
  274. std::vector<TypeId> init_data_type = {};
  275. for (size_t i = 0; i < output_tensor_num; i++) {
  276. TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i);
  277. init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
  278. }
  279. // set the format of parameter to DEFAULT_FORMAT
  280. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
  281. // set parameter initaial device data type
  282. kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
  283. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
  284. AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
  285. return new_parameter;
  286. }
  287. std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
  288. MS_EXCEPTION_IF_NULL(value_node);
  289. auto node_value = value_node->value();
  290. auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
  291. std::vector<AnfNodePtr> convert_inputs;
  292. if (!node_value->isa<ValueTuple>()) {
  293. MS_LOG(EXCEPTION) << "multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
  294. }
  295. auto value_tuple = node_value->cast<ValueTuplePtr>();
  296. MS_EXCEPTION_IF_NULL(value_tuple);
  297. if (value_tuple->size() != output_size) {
  298. MS_LOG(EXCEPTION) << "value tuple size" << value_tuple->size()
  299. << " is not mathced with the value node's output size" << output_size;
  300. }
  301. for (size_t index = 0; index < value_tuple->value().size(); ++index) {
  302. auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
  303. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)},
  304. {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
  305. AddValueNodeToGraph(new_value_node);
  306. auto kernel_info = std::make_shared<device::KernelInfo>();
  307. new_value_node->set_kernel_info(kernel_info);
  308. kernel_info->SetFeatureMapFlag(false);
  309. // create kernel_build_info for new value node
  310. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  311. // set the format of value_node to DEFAULT_FORMAT
  312. kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  313. // set value node initial device data type = infer data type
  314. kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
  315. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  316. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  317. AddValueNodeToGraph(new_value_node);
  318. convert_inputs.emplace_back(new_value_node);
  319. }
  320. if (!RemoveValueNodeFromGraph(value_node)) {
  321. MS_LOG(WARNING) << "failed to remove the value_node " << value_node->DebugString();
  322. }
  323. return convert_inputs;
  324. }
  325. ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  326. MS_EXCEPTION_IF_NULL(value_node);
  327. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  328. new_value_node->set_abstract(value_node->abstract());
  329. // create kernel_info fo new value node
  330. auto kernel_info = std::make_shared<device::KernelInfo>();
  331. kernel_info->SetFeatureMapFlag(false);
  332. new_value_node->set_kernel_info(kernel_info);
  333. // create kernel_build_info for new value node
  334. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  335. // set the format of value_node to DEFAULT_FORMAT
  336. auto output_tensor_num = AnfAlgo::GetOutputTensorNum(value_node);
  337. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
  338. // set value node initial device data type = infer data type
  339. std::vector<TypeId> types = std::vector<TypeId>(output_tensor_num, kTypeUnknown);
  340. kernel_build_info_builder->SetOutputsDeviceType(types);
  341. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  342. AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  343. return new_value_node;
  344. }
  345. const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
  346. MS_EXCEPTION_IF_NULL(inputs_);
  347. return *inputs_;
  348. }
  349. void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
  350. MS_EXCEPTION_IF_NULL(front_anf);
  351. MS_EXCEPTION_IF_NULL(backend_anf);
  352. if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
  353. MS_LOG(EXCEPTION) << "anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
  354. }
  355. if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
  356. MS_LOG(EXCEPTION) << "kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
  357. }
  358. front_backend_anf_map_[front_anf] = backend_anf;
  359. backend_front_anf_map_[backend_anf] = front_anf;
  360. }
  361. void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
  362. MS_EXCEPTION_IF_NULL(old_backend_anf);
  363. MS_EXCEPTION_IF_NULL(new_backend_anf);
  364. if (old_backend_anf == new_backend_anf) {
  365. MS_LOG(DEBUG) << "old same with new:" << old_backend_anf->DebugString();
  366. return;
  367. }
  368. if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
  369. MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
  370. return;
  371. }
  372. if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
  373. MS_LOG(EXCEPTION) << "anf is not exist in the map ,old " << old_backend_anf->DebugString();
  374. }
  375. front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
  376. backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
  377. // delete old kernel
  378. (void)backend_front_anf_map_.erase(old_backend_anf);
  379. }
  380. // get kernel by anf
  381. AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
  382. if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {
  383. return nullptr;
  384. }
  385. return front_backend_anf_map_[front_anf];
  386. }
  387. bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
  388. return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
  389. }
  390. ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
  391. if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) {
  392. return nullptr;
  393. }
  394. return tensor_to_value_node_map_[tensor];
  395. }
  396. void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
  397. MS_EXCEPTION_IF_NULL(tensor);
  398. MS_EXCEPTION_IF_NULL(value_node);
  399. tensor_to_value_node_map_[tensor] = value_node;
  400. }
  401. void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
  402. MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num;
  403. auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
  404. // add output depend edge of input
  405. auto output_it = node_output_edges_.find(input);
  406. if (output_it == node_output_edges_.end()) {
  407. node_output_edges_[input] = std::vector<std::pair<AnfNodePtr, size_t>>{output_depend_edge};
  408. } else {
  409. output_it->second.push_back(output_depend_edge);
  410. }
  411. // add input depend edge of output
  412. auto input_depend_edge = std::pair<AnfNodePtr, size_t>(input, depend_edge_num);
  413. auto input_it = node_input_edges_.find(node);
  414. if (input_it == node_input_edges_.end()) {
  415. node_input_edges_[node] = std::vector<std::pair<AnfNodePtr, size_t>>{input_depend_edge};
  416. } else {
  417. input_it->second.push_back(input_depend_edge);
  418. }
  419. // add node input depend num
  420. auto depend_it = node_input_num_.find(node);
  421. if (depend_it == node_input_num_.end()) {
  422. node_input_num_[node] = depend_edge_num;
  423. } else {
  424. depend_it->second += depend_edge_num;
  425. }
  426. }
  427. std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
  428. MS_EXCEPTION_IF_NULL(node);
  429. auto it = node_output_edges_.find(node);
  430. if (it == node_output_edges_.end()) {
  431. MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]";
  432. }
  433. std::vector<AnfNodePtr> output_nodes;
  434. auto trans = [](const std::pair<AnfNodePtr, size_t> &pair) -> AnfNodePtr { return pair.first; };
  435. (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans);
  436. return output_nodes;
  437. }
  438. // update the depend relations of control depend
  439. void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
  440. for (const auto &node : depends) {
  441. MS_EXCEPTION_IF_NULL(node);
  442. if (!node->isa<CNode>()) {
  443. return;
  444. }
  445. auto cnode = node->cast<CNodePtr>();
  446. MS_EXCEPTION_IF_NULL(cnode);
  447. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
  448. MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
  449. }
  450. auto prior_node = cnode->input(kControlDependPriorIndex);
  451. auto depend_node = cnode->input(kControlDependBehindIndex);
  452. MS_EXCEPTION_IF_NULL(prior_node);
  453. MS_EXCEPTION_IF_NULL(depend_node);
  454. std::vector<AnfNodePtr> prior_nodes = {prior_node};
  455. std::vector<AnfNodePtr> depend_nodes = {depend_node};
  456. MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString();
  457. if (prior_node->isa<Parameter>()) {
  458. prior_nodes = GetOutputNodes(prior_node);
  459. }
  460. if (depend_node->isa<Parameter>()) {
  461. depend_nodes = GetOutputNodes(depend_node);
  462. }
  463. for (auto &first_node : prior_nodes) {
  464. if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
  465. continue;
  466. }
  467. for (auto &second_node : depend_nodes) {
  468. if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
  469. continue;
  470. }
  471. MS_EXCEPTION_IF_NULL(first_node);
  472. MS_EXCEPTION_IF_NULL(second_node);
  473. MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
  474. AddDependEdge(second_node, first_node, 1);
  475. }
  476. }
  477. }
  478. }
  479. bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  480. std::unordered_set<AnfNodePtr> *visited_nodes) {
  481. MS_EXCEPTION_IF_NULL(node);
  482. if (!node->isa<CNode>()) {
  483. return false;
  484. }
  485. auto cnode = node->cast<CNodePtr>();
  486. MS_EXCEPTION_IF_NULL(cnode);
  487. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
  488. return false;
  489. }
  490. // set the control depend visited but don't push it into the que
  491. if (visited_nodes->find(node) != visited_nodes->end()) {
  492. return true;
  493. }
  494. (void)visited_nodes->insert(cnode);
  495. // add a 0 depend num to keep the link relations to prepare for finding zero output nodes
  496. auto prior_node = cnode->input(kControlDependPriorIndex);
  497. auto depend_node = cnode->input(kControlDependBehindIndex);
  498. for (const auto &input : cnode->inputs()) {
  499. AddDependEdge(node, input, 0);
  500. }
  501. PushNoVisitedNode(depend_node, que, visited_nodes);
  502. PushNoVisitedNode(prior_node, que, visited_nodes);
  503. return true;
  504. }
// Rebuild node_output_edges_ / node_input_edges_ / node_input_num_ with a BFS
// from the graph's Return node. Parameters and value nodes (in-degree 0) are
// collected into *seed_nodes for the topological sort. ControlDepend nodes
// get zero-weight edges (via HandleControlDependNode) so they keep the link
// structure without affecting in-degree counts; the ordering constraints
// they express are applied afterwards by UpdateControlDependRelations.
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
  node_output_edges_.clear();
  node_input_num_.clear();
  node_input_edges_.clear();
  std::vector<AnfNodePtr> control_depends;
  std::unordered_set<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> que;
  que.push(get_return());
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<Parameter>() || node->isa<ValueNode>()) {
      // leaves of the traversal become seeds for the topological sort
      seed_nodes->push(node);
      continue;
    }
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // handle data links
    for (const auto &input : cnode->inputs()) {
      size_t depend_edge_num = 1;
      // handle control depend,all inputs of control depend has no depend edge
      if (HandleControlDependNode(input, &que, &visited_nodes)) {
        control_depends.push_back(input);
        depend_edge_num = 0;
      }
      PushNoVisitedNode(input, &que, &visited_nodes);
      AddDependEdge(node, input, depend_edge_num);
    }
  }
  UpdateControlDependRelations(control_depends);
}
  540. void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); }
  541. bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }
  542. AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
  543. if (!IsInRefOutputMap(out_pair)) {
  544. MS_LOG(EXCEPTION) << "out_pair is not in RefOutputMap";
  545. }
  546. return ref_out_in_map_.at(out_pair);
  547. }
  548. void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
  549. if (IsInRefOutputMap(final_pair)) {
  550. MS_LOG(EXCEPTION) << "out_pair is already in RefOutputMap";
  551. }
  552. (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
  553. }
  554. bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
  555. if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
  556. (void)graph_value_nodes_.erase(value_node);
  557. return true;
  558. }
  559. return false;
  560. }
  561. void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
  562. MS_EXCEPTION_IF_NULL(inputs_);
  563. auto it = node_output_edges_.find(old_anf_node);
  564. if (it != node_output_edges_.end()) {
  565. const auto &outputs = it->second;
  566. for (auto &output_node : outputs) {
  567. MS_EXCEPTION_IF_NULL(output_node.first);
  568. auto output_cnode = output_node.first->cast<CNodePtr>();
  569. MS_EXCEPTION_IF_NULL(output_cnode);
  570. auto &output_node_inputs = output_cnode->inputs();
  571. // don't replace node if it is a control edge => output_node.second == 0
  572. if (output_node.second == 0) {
  573. continue;
  574. }
  575. for (size_t i = 1; i < output_node_inputs.size(); i++) {
  576. if (output_node_inputs[i] == old_anf_node.get()) {
  577. output_cnode->set_input(i, new_anf_node);
  578. }
  579. }
  580. // update graph inputs
  581. for (size_t i = 0; i < inputs_->size(); i++) {
  582. if ((*inputs_)[i] == old_anf_node.get()) {
  583. MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
  584. << ",new graph input:" << new_anf_node->DebugString();
  585. (*inputs_)[i] = new_anf_node.get();
  586. break;
  587. }
  588. }
  589. }
  590. // update front to backend map
  591. FrontBackendlMapUpdate(old_anf_node, new_anf_node);
  592. // update output depend relations
  593. node_output_edges_[new_anf_node.get()] = it->second;
  594. (void)node_output_edges_.erase(old_anf_node);
  595. }
  596. // update graph inputs in child graph
  597. auto it_real_inputs = real_inputs_.find(old_anf_node);
  598. if (it_real_inputs != real_inputs_.end()) {
  599. // insert new parameter to map
  600. auto iter = real_inputs_.find(new_anf_node);
  601. if (iter != real_inputs_.end()) {
  602. MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
  603. iter->second = it_real_inputs->second;
  604. } else {
  605. real_inputs_[new_anf_node.get()] = it_real_inputs->second;
  606. }
  607. // erase old parameter in map
  608. real_inputs_.erase(old_anf_node);
  609. }
  610. }
  611. void KernelGraph::UpdateExecuteKernelStreamLabel() {
  612. for (auto &kernel : execution_order_) {
  613. AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
  614. }
  615. }
  616. std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
  617. std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
  618. if (IsLeafGraph()) {
  619. leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
  620. } else {
  621. for (const auto &child_graph : child_graph_order_) {
  622. MS_EXCEPTION_IF_NULL(child_graph);
  623. auto child_leaf_graph_order = child_graph->GetLeafGraphOrder();
  624. std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
  625. }
  626. }
  627. return leaf_graph_order;
  628. }
  629. bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
  630. std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
  631. std::vector<CNodePtr> result;
  632. for (const auto &anf : execution_order_) {
  633. if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
  634. result.push_back(anf->cast<CNodePtr>());
  635. }
  636. }
  637. return result;
  638. }
  639. void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
  640. MS_EXCEPTION_IF_NULL(parameter);
  641. MS_EXCEPTION_IF_NULL(arg);
  642. MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
  643. MS_EXCEPTION_IF_NULL(parameter);
  644. MS_EXCEPTION_IF_NULL(arg);
  645. if (real_inputs_.find(parameter) == real_inputs_.end()) {
  646. real_inputs_[parameter] = std::vector<AnfNodePtr>();
  647. }
  648. auto &args = real_inputs_[parameter];
  649. (void)args.push_back(arg);
  650. }
  651. std::vector<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
  652. MS_EXCEPTION_IF_NULL(parameter);
  653. auto iter = real_inputs_.find(parameter);
  654. if (iter != real_inputs_.end()) {
  655. return iter->second;
  656. }
  657. MS_LOG(EXCEPTION) << parameter->DebugString() << " not found.";
  658. }
  659. void KernelGraph::UpdateCallRealInput() {
  660. MS_LOG(INFO) << "Update graph id: " << graph_id_;
  661. std::map<AnfNodePtr, std::vector<AnfNodePtr>> real_inputs_map;
  662. for (auto &it : real_inputs_) {
  663. auto parameter = it.first;
  664. MS_EXCEPTION_IF_NULL(parameter);
  665. auto real_inputs = it.second;
  666. std::vector<AnfNodePtr> new_real_inputs;
  667. std::set<AnfNodePtr> erase_real_inputs;
  668. for (auto &real_input : real_inputs) {
  669. // if real input is a call node ,find the child graph output act as the new real input
  670. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0);
  671. MS_EXCEPTION_IF_NULL(item_with_index.first);
  672. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
  673. (void)erase_real_inputs.insert(item_with_index.first);
  674. new_real_inputs = GetCallRealOutputs(item_with_index.first);
  675. continue;
  676. }
  677. }
  678. for (auto &erase_node : erase_real_inputs) {
  679. MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString();
  680. for (auto iter = real_inputs.begin(); iter != real_inputs.end();) {
  681. if (*iter == erase_node) {
  682. iter = real_inputs.erase(iter);
  683. } else {
  684. ++iter;
  685. }
  686. }
  687. }
  688. for (auto &new_real_input : new_real_inputs) {
  689. MS_LOG(INFO) << "paramter: " << parameter->DebugString()
  690. << " insert real input:" << new_real_input->DebugString();
  691. (void)real_inputs.push_back(new_real_input);
  692. }
  693. real_inputs_map[parameter] = real_inputs;
  694. }
  695. real_inputs_ = real_inputs_map;
  696. }
  697. void KernelGraph::PrintGraphExecuteOrder() const {
  698. MS_LOG(INFO) << "graph:" << graph_id_ << "execution order";
  699. for (size_t i = 0; i < execution_order_.size(); i++) {
  700. CNodePtr cur_cnode_ptr = execution_order_[i];
  701. MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
  702. std::string event_str;
  703. std::string label_str;
  704. if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
  705. event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
  706. }
  707. if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
  708. label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
  709. }
  710. if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
  711. auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
  712. label_str = ", label_id[";
  713. for (size_t j = 0; j < label_list.size(); ++j) {
  714. label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
  715. }
  716. }
  717. MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
  718. << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
  719. << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
  720. << event_str << label_str;
  721. }
  722. }
  723. std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
// Destructor: tell the singleton KernelRuntimeManager to clear whatever
// per-graph resources it still tracks under this graph's id.
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
  725. } // namespace session
  726. } // namespace mindspore