You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

helper.cc 31 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/common/helper.h"
  17. #include <string>
  18. #include <utility>
  19. #include <unordered_set>
  20. #include <algorithm>
  21. #include <map>
  22. #include <set>
  23. #include <deque>
  24. #include "utils/utils.h"
  25. #include "base/base_ref.h"
  26. #include "backend/session/anf_runtime_algorithm.h"
  27. #include "base/core_ops.h"
  28. #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
  29. #include "frontend/operator/ops.h"
  30. #include "utils/ms_utils.h"
  31. #include "runtime/device/kernel_info.h"
  32. #include "utils/ms_context.h"
  33. namespace mindspore {
  34. namespace opt {
  // Number of bytes in a 32-bit element (32 / 8).
  // NOTE(review): not referenced by any code visible in this file — confirm whether it is still needed.
  constexpr size_t kType32Len{4};
  36. std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
  37. std::vector<int64_t> result;
  38. (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
  39. return result;
  40. }
  41. std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
  42. std::vector<int64_t> result;
  43. (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToLong);
  44. return result;
  45. }
  46. bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
  47. MS_EXCEPTION_IF_NULL(node);
  48. std::vector<AnfNodePtr> node_list = TopoSort(graph.get_return());
  49. std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
  50. for (auto &nd : node_list) {
  51. MS_EXCEPTION_IF_NULL(nd);
  52. if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
  53. auto control_depend = nd->cast<CNodePtr>();
  54. auto prior_node = control_depend->input(kControlDependPriorIndex);
  55. auto behind_node = control_depend->input(kControlDependBehindIndex);
  56. auto it = control_depend_map.find(behind_node);
  57. if (it == control_depend_map.end()) {
  58. control_depend_map[behind_node] = std::set<AnfNodePtr>{prior_node};
  59. } else {
  60. it->second.insert(prior_node);
  61. }
  62. }
  63. }
  64. FuncGraphManagerPtr manager = graph.manager();
  65. MS_EXCEPTION_IF_NULL(manager);
  66. std::unordered_set<AnfNodePtr> seen_node;
  67. std::deque<AnfNodePtr> todo{node};
  68. while (!todo.empty()) {
  69. AnfNodePtr nd = todo.front();
  70. todo.pop_front();
  71. if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
  72. continue;
  73. }
  74. (void)seen_node.insert(nd);
  75. if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
  76. return true;
  77. }
  78. if (nd->isa<CNode>()) {
  79. auto cnode = nd->cast<CNodePtr>();
  80. MS_EXCEPTION_IF_NULL(cnode);
  81. auto inputs = cnode->inputs();
  82. (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
  83. }
  84. auto it = control_depend_map.find(nd);
  85. if (it != control_depend_map.end()) {
  86. (void)todo.insert(todo.end(), it->second.begin(), it->second.end());
  87. }
  88. }
  89. return false;
  90. }
  91. bool UnVisited(const BaseRef &n) {
  92. if (utils::isa<AnfNodePtr>(n)) {
  93. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  94. MS_EXCEPTION_IF_NULL(in);
  95. if (IsValueNode<Primitive>(in)) {
  96. auto value_node = in->cast<ValueNodePtr>();
  97. MS_EXCEPTION_IF_NULL(value_node);
  98. auto value = value_node->value();
  99. MS_EXCEPTION_IF_NULL(value);
  100. auto prim_py = value->cast<PrimitivePtr>();
  101. MS_EXCEPTION_IF_NULL(prim_py);
  102. return !prim_py->HasAttr(kAttrVisited);
  103. } else if (IsValueNode<FuncGraph>(in)) {
  104. auto func_graph = GetValueNode<FuncGraphPtr>(in);
  105. MS_EXCEPTION_IF_NULL(func_graph);
  106. return !func_graph->has_flag(kAttrVisited);
  107. }
  108. return false;
  109. }
  110. return false;
  111. }
  112. CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
  113. MS_EXCEPTION_IF_NULL(node);
  114. if (!node->isa<CNode>()) {
  115. MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
  116. }
  117. auto cnode = node->cast<CNodePtr>();
  118. CheckCNodeInputSize(cnode, input_size);
  119. return cnode;
  120. }
  121. void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
  122. MS_EXCEPTION_IF_NULL(cnode);
  123. auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
  124. if (real_input_tensor_num != input_tensor_size) {
  125. MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
  126. << "] of node " + cnode->DebugString() + " is not equal to " << input_tensor_size;
  127. }
  128. }
  129. bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
  130. MS_EXCEPTION_IF_NULL(node_x);
  131. MS_EXCEPTION_IF_NULL(node_y);
  132. return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
  133. AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
  134. }
  135. const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  136. MS_EXCEPTION_IF_NULL(func_graph);
  137. auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
  138. MS_EXCEPTION_IF_NULL(transop_cnode);
  139. auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
  140. auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum);
  141. auto transed_node = prev_transop_cnode->input(1);
  142. MS_EXCEPTION_IF_NULL(transed_node);
  143. std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
  144. depend_cnode->input(kDependAttachNodeIndex)};
  145. AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
  146. MS_EXCEPTION_IF_NULL(replace_depend);
  147. auto transed_abstract = transed_node->abstract();
  148. replace_depend->set_abstract(transed_abstract);
  149. return replace_depend;
  150. }
  151. bool Visited(const BaseRef &n) {
  152. if (utils::isa<AnfNodePtr>(n)) {
  153. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  154. MS_EXCEPTION_IF_NULL(in);
  155. if (IsValueNode<Primitive>(in)) {
  156. auto value_node = in->cast<ValueNodePtr>();
  157. MS_EXCEPTION_IF_NULL(value_node);
  158. auto value = value_node->value();
  159. MS_EXCEPTION_IF_NULL(value);
  160. auto prim_py = value->cast<PrimitivePtr>();
  161. MS_EXCEPTION_IF_NULL(prim_py);
  162. return prim_py->HasAttr(kAttrVisited);
  163. } else if (IsValueNode<FuncGraph>(in)) {
  164. auto func_graph = GetValueNode<FuncGraphPtr>(in);
  165. MS_EXCEPTION_IF_NULL(func_graph);
  166. return func_graph->has_flag(kAttrVisited);
  167. }
  168. return false;
  169. }
  170. return false;
  171. }
  172. void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
  173. std::vector<AnfNodePtr> *outputs) {
  174. MS_EXCEPTION_IF_NULL(func_graph);
  175. MS_EXCEPTION_IF_NULL(node);
  176. MS_EXCEPTION_IF_NULL(outputs);
  177. for (size_t i = 0; i < output_num; i++) {
  178. int64_t temp = SizeToLong(i);
  179. auto idx = NewValueNode(temp);
  180. MS_EXCEPTION_IF_NULL(idx);
  181. auto imm = std::make_shared<Int64Imm>(temp);
  182. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  183. idx->set_abstract(abstract_scalar);
  184. auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  185. MS_EXCEPTION_IF_NULL(tuple_getitem);
  186. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)},
  187. {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get());
  188. (*outputs).push_back(tuple_getitem);
  189. }
  190. }
  191. template <typename T>
  192. tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
  193. size_t data_length) {
  194. MS_EXCEPTION_IF_NULL(value_tuple_ptr);
  195. MS_EXCEPTION_IF_NULL(type_ptr);
  196. std::vector<T> values;
  197. for (const auto &v : value_tuple_ptr->value()) {
  198. MS_EXCEPTION_IF_NULL(v);
  199. if (v->isa<Scalar>()) {
  200. ScalarPtr scalar = v->cast<ScalarPtr>();
  201. values.push_back(GetValue<T>(scalar));
  202. } else {
  203. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  204. return nullptr;
  205. }
  206. }
  207. std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
  208. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
  209. MS_EXCEPTION_IF_NULL(tensor);
  210. tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
  211. tensor->set_device_info(device_info);
  212. auto data_ptr = tensor->data_c();
  213. MS_EXCEPTION_IF_NULL(data_ptr);
  214. auto elem_num = values.size() * data_length;
  215. auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
  216. if (ret_code != 0) {
  217. MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
  218. }
  219. return tensor;
  220. }
  221. tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  222. MS_EXCEPTION_IF_NULL(value_tuple);
  223. tensor::TensorPtr tensor = nullptr;
  224. if (value_tuple->value().empty()) {
  225. MS_LOG(WARNING) << "The value tuple is empty.";
  226. return nullptr;
  227. }
  228. ValuePtr v = *(value_tuple->value().begin());
  229. MS_EXCEPTION_IF_NULL(v);
  230. // Currently we only deal with the scalar tuple
  231. if (!v->isa<Scalar>()) {
  232. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  233. return nullptr;
  234. }
  235. ScalarPtr scalar = v->cast<ScalarPtr>();
  236. MS_EXCEPTION_IF_NULL(scalar);
  237. if (scalar->isa<Int32Imm>()) {
  238. tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
  239. } else if (scalar->isa<Int64Imm>()) {
  240. tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
  241. } else if (scalar->isa<FloatImm>()) {
  242. tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
  243. } else {
  244. auto type = scalar->type();
  245. auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
  246. MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
  247. return nullptr;
  248. }
  249. return tensor;
  250. }
  251. bool IsNopNode(const AnfNodePtr &node) {
  252. auto context_ptr = MsContext::GetInstance();
  253. MS_EXCEPTION_IF_NULL(context_ptr);
  254. auto target = GetCNodeTarget(node);
  255. if (target == kCPUDevice) {
  256. return false;
  257. }
  258. if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
  259. context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
  260. return false;
  261. }
  262. static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
  263. prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
  264. kFlattenGradOpName, prim::kPrimReformat->name()};
  265. if (node == nullptr || !node->isa<CNode>()) {
  266. return false;
  267. }
  268. CNodePtr cnode = node->cast<CNodePtr>();
  269. MS_EXCEPTION_IF_NULL(cnode);
  270. if (cnode->inputs().empty()) {
  271. return false;
  272. }
  273. auto input0 = cnode->input(0);
  274. MS_EXCEPTION_IF_NULL(input0);
  275. if (!input0->isa<ValueNode>()) {
  276. return false;
  277. }
  278. bool is_nop_node = false;
  279. if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
  280. is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
  281. }
  282. if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
  283. return false;
  284. }
  285. return true;
  286. }
  287. bool IsAllNopNode(const session::KernelGraph *const graph) {
  288. MS_EXCEPTION_IF_NULL(graph);
  289. auto execution_order = graph->execution_order();
  290. for (auto &cnode : execution_order) {
  291. MS_EXCEPTION_IF_NULL(cnode);
  292. if (!IsNopNode(cnode)) {
  293. return false;
  294. }
  295. }
  296. return true;
  297. }
  298. bool CheckNopNodeIsOutputNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
  299. MS_EXCEPTION_IF_NULL(node);
  300. // if node is not a nop node, keep it in execution order
  301. if (!IsNopNode(node)) {
  302. return true;
  303. }
  304. // if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
  305. if (is_dynamic_graph) {
  306. auto iter = find(outputs.begin(), outputs.end(), node);
  307. if (iter != outputs.end()) {
  308. return true;
  309. }
  310. }
  311. return false;
  312. }
  313. void HideNopNode(session::KernelGraph *const graph) {
  314. MS_EXCEPTION_IF_NULL(graph);
  315. if (IsAllNopNode(graph) == true) {
  316. return;
  317. }
  318. auto execution_order = graph->execution_order();
  319. auto outputs = graph->outputs();
  320. bool is_dynamic_graph = graph->is_dynamic_shape();
  321. MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
  322. std::vector<CNodePtr> new_nodes;
  323. for (auto &cnode : execution_order) {
  324. MS_EXCEPTION_IF_NULL(cnode);
  325. if (CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
  326. new_nodes.push_back(cnode);
  327. }
  328. }
  329. graph->set_execution_order(new_nodes);
  330. MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
  331. }
  332. void RemoveNopNode(session::KernelGraph *const graph) {
  333. MS_EXCEPTION_IF_NULL(graph);
  334. if (IsAllNopNode(graph) == true) {
  335. return;
  336. }
  337. bool changed = true;
  338. while (changed) {
  339. changed = false;
  340. std::vector<CNodePtr> new_nodes;
  341. auto outputs = graph->outputs();
  342. bool is_dynamic_graph = graph->is_dynamic_shape();
  343. for (auto &cnode : graph->execution_order()) {
  344. MS_EXCEPTION_IF_NULL(cnode);
  345. // ignore nop node itself
  346. if (!CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
  347. continue;
  348. }
  349. // Replace the input which is nop node
  350. std::vector<AnfNodePtr> new_inputs;
  351. new_inputs.push_back(cnode->input(0));
  352. bool need_update = false;
  353. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  354. auto input = cnode->input(i);
  355. MS_EXCEPTION_IF_NULL(input);
  356. auto cinput = input->cast<CNodePtr>();
  357. if (cinput == nullptr || !IsNopNode(cinput)) {
  358. new_inputs.push_back(input);
  359. continue;
  360. }
  361. if (cinput->inputs().size() == 2) {
  362. new_inputs.push_back(cinput->input(1));
  363. need_update = true;
  364. changed = true;
  365. } else {
  366. new_inputs.push_back(input);
  367. }
  368. }
  369. if (need_update) {
  370. cnode->set_inputs(new_inputs);
  371. }
  372. // push into new execution list
  373. new_nodes.push_back(cnode);
  374. }
  375. graph->set_execution_order(new_nodes);
  376. }
  377. }
  378. size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  379. auto out_list = GetRealNodeUsedList(graph, node);
  380. MS_EXCEPTION_IF_NULL(out_list);
  381. return out_list->size();
  382. }
  383. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
  384. const AnfNodePtr &node) {
  385. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  386. MS_EXCEPTION_IF_NULL(graph);
  387. auto manager = graph->manager();
  388. MS_EXCEPTION_IF_NULL(manager);
  389. auto iter = manager->node_users().find(node);
  390. if (iter == manager->node_users().end()) {
  391. MS_LOG(EXCEPTION) << "node has no output in manager";
  392. }
  393. auto output_info_list = iter->second;
  394. for (const auto &output_info : output_info_list) {
  395. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
  396. output_info.second == kDependAttachNodeIndex) {
  397. continue;
  398. }
  399. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) {
  400. continue;
  401. }
  402. output_node_list->push_back(output_info);
  403. }
  404. return output_node_list;
  405. }
  406. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
  407. const AnfNodePtr &node,
  408. size_t output_index) {
  409. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  410. MS_EXCEPTION_IF_NULL(graph);
  411. auto manager = graph->manager();
  412. MS_EXCEPTION_IF_NULL(manager);
  413. auto iter = manager->node_users().find(node);
  414. if (iter == manager->node_users().end()) {
  415. MS_LOG(EXCEPTION) << "node has no output in manager";
  416. }
  417. auto output_info_list = iter->second;
  418. for (const auto &output_info : output_info_list) {
  419. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
  420. continue;
  421. }
  422. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
  423. output_info.second == kDependAttachNodeIndex) {
  424. continue;
  425. }
  426. size_t used_output_index;
  427. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) {
  428. used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
  429. } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
  430. used_output_index = output_index;
  431. } else {
  432. auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, output_info.second - 1);
  433. if (kernel_with_index.first.get() != node.get()) {
  434. MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
  435. }
  436. used_output_index = kernel_with_index.second;
  437. }
  438. if (used_output_index == output_index) {
  439. output_node_list->push_back(output_info);
  440. }
  441. }
  442. return output_node_list;
  443. }
  444. bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  445. MS_EXCEPTION_IF_NULL(graph);
  446. MS_EXCEPTION_IF_NULL(node);
  447. auto output_node_list = GetRealNodeUsedList(graph, node);
  448. MS_EXCEPTION_IF_NULL(output_node_list);
  449. return output_node_list->size() > 1;
  450. }
  451. bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  452. MS_EXCEPTION_IF_NULL(graph);
  453. MS_EXCEPTION_IF_NULL(node);
  454. auto output_node_list = GetRealNodeUsedList(graph, node);
  455. MS_EXCEPTION_IF_NULL(output_node_list);
  456. if (output_node_list->empty()) {
  457. return true;
  458. }
  459. for (const auto &output : *output_node_list) {
  460. auto out_node = output.first;
  461. auto name = AnfAlgo::GetCNodeName(out_node);
  462. if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
  463. name == prim::kPrimTupleGetItem->name()) {
  464. auto result = IsNotRealUsedByOthers(graph, out_node);
  465. if (!result) {
  466. return result;
  467. }
  468. continue;
  469. }
  470. return false;
  471. }
  472. return true;
  473. }
  474. CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
  475. auto idx = NewValueNode(SizeToLong(output_idx));
  476. MS_EXCEPTION_IF_NULL(idx);
  477. auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
  478. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  479. idx->set_abstract(abstract_scalar);
  480. CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  481. MS_EXCEPTION_IF_NULL(tuple_getitem);
  482. tuple_getitem->set_scope(node->scope());
  483. std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
  484. TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
  485. AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  486. return tuple_getitem;
  487. }
  488. void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
  489. MS_EXCEPTION_IF_NULL(cnode);
  490. std::vector<AnfNodePtr> new_inputs;
  491. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  492. MS_EXCEPTION_IF_NULL(primitive);
  493. primitive = primitive->Clone();
  494. auto input_names = primitive->GetAttr(kAttrInputNames);
  495. if (input_names == nullptr) {
  496. MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
  497. return;
  498. }
  499. auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  500. auto inputs = cnode->inputs();
  501. new_inputs.push_back(inputs[0]);
  502. bool need_update = false;
  503. for (size_t i = 0; i < inputs.size() - 1; ++i) {
  504. auto input_node = inputs[i + 1];
  505. if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
  506. input_node = AnfAlgo::VisitKernel(input_node, 0).first;
  507. }
  508. MS_EXCEPTION_IF_NULL(input_node);
  509. if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
  510. auto value_node = input_node->cast<ValueNodePtr>();
  511. MS_EXCEPTION_IF_NULL(value_node);
  512. MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
  513. if (i >= input_names_vec.size()) {
  514. MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
  515. }
  516. primitive->set_attr(input_names_vec[i], value_node->value());
  517. need_update = true;
  518. } else {
  519. new_inputs.push_back(inputs[i + 1]);
  520. }
  521. }
  522. if (need_update) {
  523. // Update cnode's inputs
  524. new_inputs[0] = NewValueNode(primitive);
  525. cnode->set_inputs(new_inputs);
  526. }
  527. }
  528. bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  529. if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
  530. auto a_node = utils::cast<AnfNodePtr>(a);
  531. auto b_node = utils::cast<AnfNodePtr>(b);
  532. MS_EXCEPTION_IF_NULL(a_node);
  533. MS_EXCEPTION_IF_NULL(b_node);
  534. if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
  535. auto a_value_node = a_node->cast<ValueNodePtr>();
  536. MS_EXCEPTION_IF_NULL(a_value_node);
  537. auto a_value = a_value_node->value();
  538. MS_EXCEPTION_IF_NULL(a_value);
  539. auto a_prim = a_value->cast<PrimitivePtr>();
  540. MS_EXCEPTION_IF_NULL(a_prim);
  541. auto b_value_node = b_node->cast<ValueNodePtr>();
  542. MS_EXCEPTION_IF_NULL(b_value_node);
  543. auto b_value = b_value_node->value();
  544. MS_EXCEPTION_IF_NULL(b_value);
  545. auto b_prim = b_value->cast<PrimitivePtr>();
  546. MS_EXCEPTION_IF_NULL(b_prim);
  547. return a_prim->name() == b_prim->name();
  548. } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
  549. auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  550. if (a_value_node_ptr == nullptr) {
  551. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  552. }
  553. auto a_value_ptr = a_value_node_ptr->value();
  554. if (a_value_ptr == nullptr) {
  555. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  556. }
  557. auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  558. if (b_value_node_ptr == nullptr) {
  559. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  560. }
  561. auto b_value_ptr = b_value_node_ptr->value();
  562. if (b_value_ptr == nullptr) {
  563. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  564. }
  565. return (*a_value_ptr) == (*b_value_ptr);
  566. }
  567. MS_LOG(DEBUG) << "check AnfNodePtr equal";
  568. }
  569. if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
  570. MS_LOG(DEBUG) << "check GraphPtr equal";
  571. }
  572. return a == b;
  573. }
  574. bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  575. // To matchCNode and Kernel's type
  576. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  577. return true;
  578. }
  579. return a.type() == b.type();
  580. }
  581. namespace {
  582. ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  583. if (utils::isa<int>(sexp)) {
  584. return NewValueNode(utils::cast<int>(sexp));
  585. }
  586. if (utils::isa<int64_t>(sexp)) {
  587. return NewValueNode(utils::cast<int64_t>(sexp));
  588. }
  589. if (utils::isa<float>(sexp)) {
  590. return NewValueNode(utils::cast<float>(sexp));
  591. }
  592. if (utils::isa<bool>(sexp)) {
  593. return NewValueNode(utils::cast<bool>(sexp));
  594. }
  595. if (utils::isa<ValuePtr>(sexp)) {
  596. return NewValueNode(utils::cast<ValuePtr>(sexp));
  597. }
  598. return nullptr;
  599. }
  600. CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  601. if (utils::isa<FuncGraphPtr>(graph)) {
  602. return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  603. }
  604. if (utils::isa<VarPtr>(graph)) {
  605. return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  606. }
  607. return nullptr;
  608. }
  609. VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  610. if (utils::isa<VarPtr>(graph)) {
  611. MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
  612. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  613. }
  614. if (utils::isa<FuncGraphPtr>(graph)) {
  615. MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  616. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  617. }
  618. MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  619. return nullptr;
  620. }
  621. AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  622. bool multigraph) {
  623. MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  624. std::vector<AnfNodePtr> input_nodes;
  625. const auto &tuple = utils::cast<VectorRef>(sexp);
  626. if (multigraph && utils::isa<VarPtr>(graph)) {
  627. for (auto &x : tuple) {
  628. AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
  629. input_nodes.push_back(node);
  630. }
  631. VarPtr var_ptr = utils::cast<VarPtr>(graph);
  632. return std::make_shared<CNode>(input_nodes, var_ptr);
  633. }
  634. for (auto &x : tuple) {
  635. AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
  636. input_nodes.push_back(node);
  637. }
  638. return CreateCNodeWithGraph(input_nodes, graph);
  639. }
  640. } // namespace
  641. AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  642. MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  643. MS_EXCEPTION_IF_NULL(primitive_vars);
  644. if (utils::isa<VectorRef>(sexp)) {
  645. return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  646. }
  647. if (utils::isa<VarPtr>(sexp)) {
  648. auto var_ptr = utils::cast<VarPtr>(sexp);
  649. MS_EXCEPTION_IF_NULL(var_ptr);
  650. if (var_ptr->primitive()) {
  651. (*primitive_vars)[var_ptr->primitive()] = var_ptr;
  652. return NewValueNode(var_ptr->primitive());
  653. }
  654. return CreateVarNodeWithSexp(sexp, graph);
  655. }
  656. if (utils::isa<AnfNodePtr>(sexp)) {
  657. return utils::cast<AnfNodePtr>(sexp);
  658. }
  659. auto value_node = CreateValueNodeWithSexp(sexp);
  660. if (value_node == nullptr) {
  661. MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
  662. }
  663. return value_node;
  664. }
  665. bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
  666. MS_EXCEPTION_IF_NULL(equiv1);
  667. MS_EXCEPTION_IF_NULL(equiv2);
  668. MS_EXCEPTION_IF_NULL(var_node);
  669. auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
  670. MS_EXCEPTION_IF_NULL(equiv1_node);
  671. auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
  672. MS_EXCEPTION_IF_NULL(equiv2_node);
  673. return *equiv1_node == *equiv2_node;
  674. }
  675. AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
  676. MS_EXCEPTION_IF_NULL(equiv);
  677. MS_EXCEPTION_IF_NULL(var_node);
  678. auto iter = (*equiv).find(var_node);
  679. if (iter == (*equiv).end()) {
  680. MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
  681. return nullptr;
  682. }
  683. auto res = utils::cast<AnfNodePtr>(iter->second);
  684. if (res == nullptr) {
  685. MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
  686. }
  687. return res;
  688. }
  689. bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
  690. MS_EXCEPTION_IF_NULL(n1);
  691. MS_EXCEPTION_IF_NULL(n2);
  692. auto n1_cnode = n1->cast<CNodePtr>();
  693. auto n2_cnode = n2->cast<CNodePtr>();
  694. MS_EXCEPTION_IF_NULL(n1_cnode);
  695. MS_EXCEPTION_IF_NULL(n2_cnode);
  696. auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  697. MS_EXCEPTION_IF_NULL(index_input1);
  698. auto value_node1 = index_input1->cast<ValueNodePtr>();
  699. MS_EXCEPTION_IF_NULL(value_node1);
  700. auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  701. MS_EXCEPTION_IF_NULL(index_input2);
  702. auto value_node2 = index_input2->cast<ValueNodePtr>();
  703. MS_EXCEPTION_IF_NULL(value_node2);
  704. return GetValue<int64_t>(value_node1->value()) < GetValue<int64_t>(value_node2->value());
  705. }
  706. bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
  707. MS_EXCEPTION_IF_NULL(node);
  708. if (!node->isa<CNode>()) {
  709. MS_LOG(INFO) << "node is not a cnode";
  710. return false;
  711. }
  712. auto cnode = node->cast<CNodePtr>();
  713. MS_EXCEPTION_IF_NULL(cnode);
  714. return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
  715. }
  716. bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
  717. MS_EXCEPTION_IF_NULL(node);
  718. TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
  719. if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
  720. return true;
  721. }
  722. MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
  723. return false;
  724. }
  725. ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
  726. MS_EXCEPTION_IF_NULL(value_node);
  727. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  728. new_value_node->set_abstract(value_node->abstract());
  729. // create kernel_info fo new value node
  730. auto kernel_info = std::make_shared<device::KernelInfo>();
  731. new_value_node->set_kernel_info(kernel_info);
  732. // create kernel_build_info for new value node
  733. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  734. // set the format of value_node to DEFAULT_FORMAT
  735. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  736. // set value node initial device data type = infer data type
  737. std::vector<TypeId> types;
  738. size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
  739. for (size_t index = 0; index < output_num; ++index) {
  740. types.push_back(kTypeUnknown);
  741. }
  742. kernel_build_info_builder->SetOutputsDeviceType(types);
  743. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  744. return new_value_node;
  745. }
  746. void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
  747. MS_EXCEPTION_IF_NULL(old_node);
  748. MS_EXCEPTION_IF_NULL(graph);
  749. auto manager = graph->manager();
  750. MS_EXCEPTION_IF_NULL(manager);
  751. // find BatchNorm's output which is a Depend or ControlDepend
  752. for (const auto &node_index : manager->node_users()[old_node]) {
  753. AnfNodePtr output = node_index.first;
  754. size_t index = IntToSize(node_index.second);
  755. MS_EXCEPTION_IF_NULL(output);
  756. if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
  757. auto control_depend = output->cast<CNodePtr>();
  758. MS_EXCEPTION_IF_NULL(control_depend);
  759. control_depend->set_input(index, new_node);
  760. } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
  761. auto depend = output->cast<CNodePtr>();
  762. MS_EXCEPTION_IF_NULL(depend);
  763. depend->set_input(index, new_node);
  764. }
  765. }
  766. }
  767. } // namespace opt
  768. } // namespace mindspore