You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.cc 44 kB

5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/common/helper.h"
  17. #include <string>
  18. #include <utility>
  19. #include <algorithm>
  20. #include <map>
  21. #include <set>
  22. #include <deque>
  23. #include "utils/hash_set.h"
  24. #include "utils/utils.h"
  25. #include "base/base_ref.h"
  26. #include "backend/session/anf_runtime_algorithm.h"
  27. #include "base/core_ops.h"
  28. #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
  29. #include "frontend/operator/ops.h"
  30. #include "utils/ms_utils.h"
  31. #include "utils/convert_utils.h"
  32. #include "runtime/device/kernel_info.h"
  33. #include "utils/ms_context.h"
  34. #include "utils/trace_base.h"
  35. #include "backend/optimizer/common/const_input_to_attr_registry.h"
  36. #include "abstract/primitive_infer_map.h"
  37. namespace mindspore {
  38. namespace opt {
  39. namespace {
  40. constexpr size_t kType32Len = 4;
  41. constexpr size_t kType64Len = 8;
  42. void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
  43. std::vector<AnfNodePtr> orig_real_cnodes;
  44. for (auto &orig_node : orig_nodes) {
  45. if (AnfUtils::IsRealCNodeKernel(orig_node)) {
  46. auto orig_cnode = orig_node->cast<CNodePtr>();
  47. if (AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
  48. AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
  49. }
  50. orig_real_cnodes.push_back(orig_node);
  51. }
  52. }
  53. node->AddFusedDebugInfoList(orig_real_cnodes);
  54. }
  55. } // namespace
  56. std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
  57. std::vector<int64_t> result;
  58. (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
  59. return result;
  60. }
  61. std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
  62. std::vector<int64_t> result;
  63. (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToLong);
  64. return result;
  65. }
  66. bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
  67. MS_EXCEPTION_IF_NULL(node);
  68. FuncGraphManagerPtr manager = graph.manager();
  69. MS_EXCEPTION_IF_NULL(manager);
  70. mindspore::HashSet<AnfNodePtr> seen_node;
  71. std::deque<AnfNodePtr> todo{node};
  72. while (!todo.empty()) {
  73. AnfNodePtr nd = todo.front();
  74. todo.pop_front();
  75. if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
  76. continue;
  77. }
  78. (void)seen_node.insert(nd);
  79. if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
  80. return true;
  81. }
  82. if (nd->isa<CNode>()) {
  83. auto cnode = nd->cast<CNodePtr>();
  84. MS_EXCEPTION_IF_NULL(cnode);
  85. auto inputs = cnode->inputs();
  86. (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
  87. }
  88. }
  89. return false;
  90. }
  91. bool UnVisited(const BaseRef &n) {
  92. if (utils::isa<AnfNodePtr>(n)) {
  93. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  94. MS_EXCEPTION_IF_NULL(in);
  95. if (IsValueNode<Primitive>(in)) {
  96. auto value_node = in->cast<ValueNodePtr>();
  97. MS_EXCEPTION_IF_NULL(value_node);
  98. auto value = value_node->value();
  99. MS_EXCEPTION_IF_NULL(value);
  100. auto prim_py = value->cast<PrimitivePtr>();
  101. MS_EXCEPTION_IF_NULL(prim_py);
  102. return !prim_py->HasAttr(kAttrVisited);
  103. } else if (IsValueNode<FuncGraph>(in)) {
  104. auto func_graph = GetValueNode<FuncGraphPtr>(in);
  105. MS_EXCEPTION_IF_NULL(func_graph);
  106. return !func_graph->has_flag(kAttrVisited);
  107. }
  108. return false;
  109. }
  110. return false;
  111. }
  112. CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
  113. const std::vector<AnfNodePtr> &orig_nodes) {
  114. MS_EXCEPTION_IF_NULL(fg);
  115. auto node = fg->NewCNode(inputs);
  116. MS_EXCEPTION_IF_NULL(node);
  117. UpdateDumpFlagAndDebugInfo(node, orig_nodes);
  118. return node;
  119. }
  120. CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
  121. MS_EXCEPTION_IF_NULL(fg);
  122. auto node = fg->NewCNode(cnode);
  123. MS_EXCEPTION_IF_NULL(node);
  124. UpdateDumpFlagAndDebugInfo(node, orig_nodes);
  125. return node;
  126. }
  127. CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
  128. MS_EXCEPTION_IF_NULL(node);
  129. if (!node->isa<CNode>()) {
  130. MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
  131. }
  132. auto cnode = node->cast<CNodePtr>();
  133. CheckCNodeInputSize(cnode, input_size);
  134. return cnode;
  135. }
  136. void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
  137. MS_EXCEPTION_IF_NULL(cnode);
  138. auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
  139. if (real_input_tensor_num != input_tensor_size) {
  140. MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
  141. << "] of node [" + cnode->DebugString() + "] is not equal to " << input_tensor_size
  142. << trace::DumpSourceLines(cnode);
  143. }
  144. }
  145. bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
  146. MS_EXCEPTION_IF_NULL(node_x);
  147. MS_EXCEPTION_IF_NULL(node_y);
  148. return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
  149. AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
  150. }
  151. const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  152. MS_EXCEPTION_IF_NULL(func_graph);
  153. auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
  154. MS_EXCEPTION_IF_NULL(transop_cnode);
  155. auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
  156. auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum);
  157. auto transed_node = prev_transop_cnode->input(1);
  158. MS_EXCEPTION_IF_NULL(transed_node);
  159. std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
  160. depend_cnode->input(kDependAttachNodeIndex)};
  161. AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
  162. MS_EXCEPTION_IF_NULL(replace_depend);
  163. auto transed_abstract = transed_node->abstract();
  164. replace_depend->set_abstract(transed_abstract);
  165. return replace_depend;
  166. }
  167. bool Visited(const BaseRef &n) {
  168. if (utils::isa<AnfNodePtr>(n)) {
  169. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  170. MS_EXCEPTION_IF_NULL(in);
  171. if (IsValueNode<Primitive>(in)) {
  172. auto value_node = in->cast<ValueNodePtr>();
  173. MS_EXCEPTION_IF_NULL(value_node);
  174. auto value = value_node->value();
  175. MS_EXCEPTION_IF_NULL(value);
  176. auto prim_py = value->cast<PrimitivePtr>();
  177. MS_EXCEPTION_IF_NULL(prim_py);
  178. return prim_py->HasAttr(kAttrVisited);
  179. } else if (IsValueNode<FuncGraph>(in)) {
  180. auto func_graph = GetValueNode<FuncGraphPtr>(in);
  181. MS_EXCEPTION_IF_NULL(func_graph);
  182. return func_graph->has_flag(kAttrVisited);
  183. }
  184. return false;
  185. }
  186. return false;
  187. }
  188. void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
  189. std::vector<AnfNodePtr> *outputs) {
  190. MS_EXCEPTION_IF_NULL(func_graph);
  191. MS_EXCEPTION_IF_NULL(node);
  192. MS_EXCEPTION_IF_NULL(outputs);
  193. auto type_ptr = node->Type();
  194. auto shape_ptr = node->Shape();
  195. for (size_t i = 0; i < output_num; i++) {
  196. int64_t temp = SizeToLong(i);
  197. auto idx = NewValueNode(temp);
  198. MS_EXCEPTION_IF_NULL(idx);
  199. auto imm = std::make_shared<Int64Imm>(temp);
  200. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  201. idx->set_abstract(abstract_scalar);
  202. auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  203. MS_EXCEPTION_IF_NULL(tuple_getitem);
  204. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(type_ptr, i)},
  205. {AnfAlgo::GetOutputInferShape(node, shape_ptr, i)}, tuple_getitem.get());
  206. (*outputs).push_back(tuple_getitem);
  207. }
  208. }
  209. template <typename T>
  210. tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
  211. size_t data_length) {
  212. MS_EXCEPTION_IF_NULL(value_tuple_ptr);
  213. MS_EXCEPTION_IF_NULL(type_ptr);
  214. std::vector<T> values;
  215. for (const auto &v : value_tuple_ptr->value()) {
  216. MS_EXCEPTION_IF_NULL(v);
  217. if (v->isa<Scalar>()) {
  218. ScalarPtr scalar = v->cast<ScalarPtr>();
  219. values.push_back(GetValue<T>(scalar));
  220. } else {
  221. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  222. return nullptr;
  223. }
  224. }
  225. std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
  226. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
  227. MS_EXCEPTION_IF_NULL(tensor);
  228. tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
  229. tensor->set_device_info(device_info);
  230. auto data_ptr = tensor->data_c();
  231. MS_EXCEPTION_IF_NULL(data_ptr);
  232. auto elem_num = values.size() * data_length;
  233. auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
  234. if (ret_code != 0) {
  235. MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
  236. }
  237. return tensor;
  238. }
  239. tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  240. MS_EXCEPTION_IF_NULL(value_tuple);
  241. tensor::TensorPtr tensor = nullptr;
  242. if (value_tuple->value().empty()) {
  243. MS_LOG(WARNING) << "The value tuple is empty.";
  244. return nullptr;
  245. }
  246. ValuePtr v = *(value_tuple->value().begin());
  247. MS_EXCEPTION_IF_NULL(v);
  248. // Currently we only deal with the scalar tuple
  249. if (!v->isa<Scalar>()) {
  250. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  251. return nullptr;
  252. }
  253. ScalarPtr scalar = v->cast<ScalarPtr>();
  254. MS_EXCEPTION_IF_NULL(scalar);
  255. if (scalar->isa<Int32Imm>()) {
  256. tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
  257. } else if (scalar->isa<Int64Imm>()) {
  258. tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
  259. } else if (scalar->isa<FloatImm>()) {
  260. tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
  261. } else {
  262. auto type = scalar->type();
  263. auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
  264. MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
  265. return nullptr;
  266. }
  267. return tensor;
  268. }
  269. bool IsNopNode(const AnfNodePtr &node) {
  270. auto context_ptr = MsContext::GetInstance();
  271. MS_EXCEPTION_IF_NULL(context_ptr);
  272. auto target = GetCNodeTarget(node);
  273. if (target == kCPUDevice) {
  274. return false;
  275. }
  276. if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
  277. context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
  278. return false;
  279. }
  280. static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
  281. prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
  282. kFlattenGradOpName, prim::kPrimReformat->name()};
  283. if (node == nullptr || !node->isa<CNode>()) {
  284. return false;
  285. }
  286. CNodePtr cnode = node->cast<CNodePtr>();
  287. MS_EXCEPTION_IF_NULL(cnode);
  288. if (cnode->inputs().empty()) {
  289. return false;
  290. }
  291. auto input0 = cnode->input(0);
  292. MS_EXCEPTION_IF_NULL(input0);
  293. if (!input0->isa<ValueNode>()) {
  294. return false;
  295. }
  296. bool is_nop_node = false;
  297. if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
  298. is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
  299. }
  300. if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
  301. return false;
  302. }
  303. return true;
  304. }
  305. bool IsAllNopNode(const session::KernelGraph *const graph) {
  306. MS_EXCEPTION_IF_NULL(graph);
  307. auto execution_order = graph->execution_order();
  308. for (auto &cnode : execution_order) {
  309. MS_EXCEPTION_IF_NULL(cnode);
  310. if (!IsNopNode(cnode)) {
  311. return false;
  312. }
  313. }
  314. return true;
  315. }
  316. bool NeedHideNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
  317. MS_EXCEPTION_IF_NULL(node);
  318. // if node is not a nop node, keep it in execution order
  319. if (!IsNopNode(node)) {
  320. return false;
  321. }
  322. // if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
  323. if (is_dynamic_graph) {
  324. auto iter = find(outputs.begin(), outputs.end(), node);
  325. if (iter != outputs.end()) {
  326. return false;
  327. }
  328. }
  329. return true;
  330. }
  331. void HideNopNode(session::KernelGraph *const graph) {
  332. MS_EXCEPTION_IF_NULL(graph);
  333. if (IsAllNopNode(graph) == true) {
  334. return;
  335. }
  336. auto execution_order = graph->execution_order();
  337. auto outputs = graph->outputs();
  338. bool is_dynamic_graph = graph->is_dynamic_shape();
  339. MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
  340. std::vector<CNodePtr> new_nodes;
  341. for (auto &cnode : execution_order) {
  342. MS_EXCEPTION_IF_NULL(cnode);
  343. if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
  344. AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
  345. } else {
  346. new_nodes.push_back(cnode);
  347. }
  348. }
  349. graph->set_execution_order(new_nodes);
  350. MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
  351. }
  352. void RemoveNopNode(session::KernelGraph *const graph) {
  353. MS_EXCEPTION_IF_NULL(graph);
  354. if (IsAllNopNode(graph) == true) {
  355. return;
  356. }
  357. bool changed = true;
  358. while (changed) {
  359. changed = false;
  360. std::vector<CNodePtr> new_nodes;
  361. auto outputs = graph->outputs();
  362. bool is_dynamic_graph = graph->is_dynamic_shape();
  363. for (auto &cnode : graph->execution_order()) {
  364. MS_EXCEPTION_IF_NULL(cnode);
  365. // ignore nop node itself
  366. if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
  367. AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
  368. continue;
  369. }
  370. // Replace the input which is nop node
  371. std::vector<AnfNodePtr> new_inputs;
  372. new_inputs.push_back(cnode->input(0));
  373. bool need_update = false;
  374. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  375. auto input = cnode->input(i);
  376. MS_EXCEPTION_IF_NULL(input);
  377. auto cinput = input->cast<CNodePtr>();
  378. if (cinput == nullptr || !IsNopNode(cinput)) {
  379. new_inputs.push_back(input);
  380. continue;
  381. }
  382. constexpr auto kInputSize = 2;
  383. if (cinput->inputs().size() == kInputSize) {
  384. new_inputs.push_back(cinput->input(1));
  385. need_update = true;
  386. changed = true;
  387. } else {
  388. new_inputs.push_back(input);
  389. }
  390. }
  391. if (need_update) {
  392. cnode->set_inputs(new_inputs);
  393. }
  394. // push into new execution list
  395. new_nodes.push_back(cnode);
  396. }
  397. graph->set_execution_order(new_nodes);
  398. }
  399. }
  400. size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  401. auto out_list = GetRealNodeUsedList(graph, node);
  402. MS_EXCEPTION_IF_NULL(out_list);
  403. return out_list->size();
  404. }
  405. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
  406. const AnfNodePtr &node) {
  407. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  408. MS_EXCEPTION_IF_NULL(graph);
  409. auto manager = graph->manager();
  410. MS_EXCEPTION_IF_NULL(manager);
  411. auto iter = manager->node_users().find(node);
  412. if (iter == manager->node_users().end()) {
  413. return output_node_list;
  414. }
  415. auto output_info_list = iter->second;
  416. for (const auto &output_info : output_info_list) {
  417. auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
  418. if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
  419. (cnode_name == prim::kPrimUpdateState->name())) {
  420. continue;
  421. }
  422. output_node_list->push_back(output_info);
  423. }
  424. return output_node_list;
  425. }
  426. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
  427. const AnfNodePtr &node,
  428. size_t output_index) {
  429. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  430. MS_EXCEPTION_IF_NULL(graph);
  431. auto manager = graph->manager();
  432. MS_EXCEPTION_IF_NULL(manager);
  433. auto iter = manager->node_users().find(node);
  434. if (iter == manager->node_users().end()) {
  435. MS_LOG(EXCEPTION) << "node has no output in manager";
  436. }
  437. auto output_info_list = iter->second;
  438. for (const auto &output_info : output_info_list) {
  439. auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
  440. if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
  441. (cnode_name == prim::kPrimUpdateState->name())) {
  442. continue;
  443. }
  444. size_t used_output_index;
  445. if (cnode_name == prim::kPrimTupleGetItem->name()) {
  446. used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
  447. } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
  448. used_output_index = output_index;
  449. } else {
  450. auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, IntToSize(output_info.second - 1));
  451. if (kernel_with_index.first.get() != node.get()) {
  452. MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
  453. }
  454. used_output_index = kernel_with_index.second;
  455. }
  456. if (used_output_index == output_index) {
  457. output_node_list->push_back(output_info);
  458. }
  459. }
  460. return output_node_list;
  461. }
  462. bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  463. MS_EXCEPTION_IF_NULL(graph);
  464. MS_EXCEPTION_IF_NULL(node);
  465. auto output_node_list = GetRealNodeUsedList(graph, node);
  466. MS_EXCEPTION_IF_NULL(output_node_list);
  467. return output_node_list->size() > 1;
  468. }
  469. bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  470. MS_EXCEPTION_IF_NULL(graph);
  471. MS_EXCEPTION_IF_NULL(node);
  472. auto output_node_list = GetRealNodeUsedList(graph, node);
  473. MS_EXCEPTION_IF_NULL(output_node_list);
  474. if (output_node_list->empty()) {
  475. return true;
  476. }
  477. for (const auto &output : *output_node_list) {
  478. auto out_node = output.first;
  479. auto name = AnfAlgo::GetCNodeName(out_node);
  480. if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
  481. name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
  482. auto result = IsNotRealUsedByOthers(graph, out_node);
  483. if (!result) {
  484. return result;
  485. }
  486. continue;
  487. }
  488. return false;
  489. }
  490. return true;
  491. }
  492. CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
  493. MS_EXCEPTION_IF_NULL(func_graph);
  494. auto idx = NewValueNode(SizeToLong(output_idx));
  495. MS_EXCEPTION_IF_NULL(idx);
  496. auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
  497. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  498. idx->set_abstract(abstract_scalar);
  499. CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  500. MS_EXCEPTION_IF_NULL(tuple_getitem);
  501. tuple_getitem->set_scope(node->scope());
  502. auto abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
  503. MS_EXCEPTION_IF_NULL(abs);
  504. auto abs_i = abs->elements()[output_idx];
  505. MS_EXCEPTION_IF_NULL(abs_i);
  506. tuple_getitem->set_abstract(abs_i);
  507. return tuple_getitem;
  508. }
  509. ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, bool to_tensor) {
  510. MS_EXCEPTION_IF_NULL(func_graph);
  511. auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  512. MS_EXCEPTION_IF_NULL(kernel_graph);
  513. ValuePtr shape_value = nullptr;
  514. AbstractBasePtr abstract = nullptr;
  515. if (to_tensor) {
  516. // create Tensor
  517. int64_t shape_dim = SizeToLong(shape.size());
  518. std::vector<int64_t> shape_vec_shape = {shape_dim};
  519. auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
  520. MS_EXCEPTION_IF_NULL(shape_tensor);
  521. auto data_ptr = shape_tensor->data_c();
  522. MS_EXCEPTION_IF_NULL(data_ptr);
  523. auto elem_num = shape.size() * kType64Len;
  524. auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
  525. if (ret_code != 0) {
  526. MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
  527. return nullptr;
  528. }
  529. shape_value = shape_tensor;
  530. abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
  531. } else {
  532. // create ValueTuple
  533. std::vector<ValuePtr> dim_values{};
  534. abstract::AbstractBasePtrList abs{};
  535. for (const auto &dim : shape) {
  536. dim_values.push_back(MakeValue(dim));
  537. abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
  538. }
  539. shape_value = std::make_shared<ValueTuple>(dim_values);
  540. abstract = std::make_shared<abstract::AbstractTuple>(abs);
  541. }
  542. MS_EXCEPTION_IF_NULL(shape_value);
  543. MS_EXCEPTION_IF_NULL(abstract);
  544. auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
  545. MS_EXCEPTION_IF_NULL(shape_value_node);
  546. kernel_graph->AddValueNodeToGraph(shape_value_node);
  547. return shape_value_node;
  548. }
  549. void ConstInputToAttr(const CNodePtr &cnode, const mindspore::HashSet<size_t> &input_attrs) {
  550. MS_EXCEPTION_IF_NULL(cnode);
  551. std::vector<AnfNodePtr> new_inputs;
  552. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  553. MS_EXCEPTION_IF_NULL(primitive);
  554. primitive = primitive->Clone();
  555. auto input_names = primitive->GetAttr(kAttrInputNames);
  556. if (input_names == nullptr) {
  557. MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
  558. return;
  559. }
  560. auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  561. auto inputs = cnode->inputs();
  562. new_inputs.push_back(inputs[0]);
  563. bool need_update = false;
  564. for (size_t i = 0; i < inputs.size() - 1; ++i) {
  565. auto input_node = inputs[i + 1];
  566. if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
  567. input_node = AnfAlgo::VisitKernel(input_node, 0).first;
  568. }
  569. MS_EXCEPTION_IF_NULL(input_node);
  570. if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
  571. auto value_node = input_node->cast<ValueNodePtr>();
  572. MS_EXCEPTION_IF_NULL(value_node);
  573. MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
  574. if (i >= input_names_vec.size()) {
  575. MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
  576. }
  577. auto value = value_node->value();
  578. if (value->isa<tensor::Tensor>()) {
  579. auto tensor = value->cast<tensor::TensorPtr>();
  580. if (tensor->data().const_data() == nullptr) {
  581. need_update = false;
  582. break;
  583. }
  584. }
  585. primitive->set_attr(input_names_vec[i], value);
  586. need_update = true;
  587. } else {
  588. new_inputs.push_back(inputs[i + 1]);
  589. }
  590. }
  591. if (need_update) {
  592. // Update cnode's inputs
  593. new_inputs[0] = NewValueNode(primitive);
  594. cnode->set_inputs(new_inputs);
  595. }
  596. }
  597. bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  598. if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
  599. auto a_node = utils::cast<AnfNodePtr>(a);
  600. auto b_node = utils::cast<AnfNodePtr>(b);
  601. MS_EXCEPTION_IF_NULL(a_node);
  602. MS_EXCEPTION_IF_NULL(b_node);
  603. if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
  604. auto a_value_node = a_node->cast<ValueNodePtr>();
  605. MS_EXCEPTION_IF_NULL(a_value_node);
  606. auto a_value = a_value_node->value();
  607. MS_EXCEPTION_IF_NULL(a_value);
  608. auto a_prim = a_value->cast<PrimitivePtr>();
  609. MS_EXCEPTION_IF_NULL(a_prim);
  610. auto b_value_node = b_node->cast<ValueNodePtr>();
  611. MS_EXCEPTION_IF_NULL(b_value_node);
  612. auto b_value = b_value_node->value();
  613. MS_EXCEPTION_IF_NULL(b_value);
  614. auto b_prim = b_value->cast<PrimitivePtr>();
  615. MS_EXCEPTION_IF_NULL(b_prim);
  616. return a_prim->name() == b_prim->name();
  617. } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
  618. auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  619. if (a_value_node_ptr == nullptr) {
  620. MS_LOG(EXCEPTION) << "Cast value node ptr fail.";
  621. }
  622. auto a_value_ptr = a_value_node_ptr->value();
  623. if (a_value_ptr == nullptr) {
  624. MS_LOG(EXCEPTION) << "Value ptr is nullptr.";
  625. }
  626. auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  627. if (b_value_node_ptr == nullptr) {
  628. MS_LOG(EXCEPTION) << "Cast value node ptr fail.";
  629. }
  630. auto b_value_ptr = b_value_node_ptr->value();
  631. if (b_value_ptr == nullptr) {
  632. MS_LOG(EXCEPTION) << "Value ptr is nullptr.";
  633. }
  634. return (*a_value_ptr) == (*b_value_ptr);
  635. }
  636. MS_LOG(DEBUG) << "check AnfNodePtr equal";
  637. }
  638. if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
  639. MS_LOG(DEBUG) << "check GraphPtr equal";
  640. }
  641. return a == b;
  642. }
  643. bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  644. // To matchCNode and Kernel's type
  645. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  646. return true;
  647. }
  648. return a.type() == b.type();
  649. }
  650. namespace {
  651. ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp, PrimitiveVarMap *primitive_vars) {
  652. if (utils::isa<int>(sexp)) {
  653. return NewValueNode(utils::cast<int>(sexp));
  654. }
  655. if (utils::isa<int64_t>(sexp)) {
  656. return NewValueNode(utils::cast<int64_t>(sexp));
  657. }
  658. if (utils::isa<float>(sexp)) {
  659. return NewValueNode(utils::cast<float>(sexp));
  660. }
  661. if (utils::isa<bool>(sexp)) {
  662. return NewValueNode(utils::cast<bool>(sexp));
  663. }
  664. if (utils::isa<ValuePtr>(sexp)) {
  665. auto value = utils::cast<ValuePtr>(sexp);
  666. if (utils::isa<PrimitivePtr>(sexp)) {
  667. auto prim = utils::cast<PrimitivePtr>(sexp);
  668. if (primitive_vars->find(prim) != primitive_vars->end()) {
  669. prim = std::make_shared<Primitive>(prim->name());
  670. value = prim;
  671. }
  672. (*primitive_vars)[prim] = std::make_shared<Var>(prim);
  673. }
  674. return NewValueNode(value);
  675. }
  676. return nullptr;
  677. }
  678. CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  679. if (utils::isa<FuncGraphPtr>(graph)) {
  680. return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  681. }
  682. if (utils::isa<VarPtr>(graph)) {
  683. return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  684. }
  685. return nullptr;
  686. }
  687. VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  688. if (utils::isa<VarPtr>(graph)) {
  689. MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
  690. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  691. }
  692. if (utils::isa<FuncGraphPtr>(graph)) {
  693. MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  694. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  695. }
  696. MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  697. return nullptr;
  698. }
  699. AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  700. bool multigraph) {
  701. MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  702. std::vector<AnfNodePtr> input_nodes;
  703. const auto &tuple = utils::cast<VectorRef>(sexp);
  704. if (multigraph && utils::isa<VarPtr>(graph)) {
  705. for (auto &x : tuple) {
  706. AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
  707. input_nodes.push_back(node);
  708. }
  709. VarPtr var_ptr = utils::cast<VarPtr>(graph);
  710. return std::make_shared<CNode>(input_nodes, var_ptr);
  711. }
  712. for (auto &x : tuple) {
  713. AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
  714. input_nodes.push_back(node);
  715. }
  716. return CreateCNodeWithGraph(input_nodes, graph);
  717. }
  718. // rectify absttract if the input has been converted to the attr
  719. AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
  720. const AbstractBasePtrList &input_abstract) {
  721. MS_EXCEPTION_IF_NULL(primitive);
  722. opt::ConstInputToAttrInfoRegister reg;
  723. if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) {
  724. return input_abstract;
  725. }
  726. if (AnfAlgo::HasDynamicShapeFlag(primitive)) {
  727. return input_abstract;
  728. }
  729. auto ms_context = MsContext::GetInstance();
  730. MS_EXCEPTION_IF_NULL(ms_context);
  731. auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  732. if (device == kGPUDevice) {
  733. if (DynamicShapeConstInputToAttrGPU.find(primitive->name()) != DynamicShapeConstInputToAttrGPU.end()) {
  734. return input_abstract;
  735. }
  736. } else if (DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) {
  737. return input_abstract;
  738. }
  739. auto convert_input_list = reg.GetConstInputAttrInfo();
  740. auto input_names = primitive->GetAttr(kAttrInputNames);
  741. if (input_names == nullptr) {
  742. return input_abstract;
  743. }
  744. auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  745. AbstractBasePtrList rectify_abs_list;
  746. size_t ori_index = 0;
  747. rectify_abs_list.resize(input_names_vec.size());
  748. for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
  749. // if convert input list find the index it means the input has been converted to the attr
  750. if (convert_input_list.find(index) != convert_input_list.end()) {
  751. AbstractBasePtr rectify_abs = nullptr;
  752. auto input_name = input_names_vec[index];
  753. auto attr = primitive->GetAttr(input_name);
  754. if (attr != nullptr) {
  755. rectify_abs = attr->ToAbstract();
  756. } else {
  757. MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index
  758. << " input name :" << input_name << "has not been converted to the attr";
  759. rectify_abs = input_abstract[ori_index++];
  760. }
  761. rectify_abs_list[index] = rectify_abs;
  762. continue;
  763. }
  764. if (ori_index > input_abstract.size()) {
  765. MS_LOG(EXCEPTION) << "Index " << ori_index << " is out of range in input abstract size " << input_abstract.size();
  766. }
  767. rectify_abs_list[index] = input_abstract[ori_index++];
  768. }
  769. return rectify_abs_list;
  770. }
  771. AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
  772. const AbstractBasePtrList &input_abstract) {
  773. auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
  774. if (dynamic_inputs_list == nullptr) {
  775. return input_abstract;
  776. }
  777. AbstractBasePtrList rectifyed_abs_list;
  778. const int kNotDynamicFlag = -1;
  779. auto dynamic_inputs_index = GetValue<std::vector<int64_t>>(dynamic_inputs_list);
  780. size_t input_index = 0;
  781. for (auto item : dynamic_inputs_index) {
  782. if (item == kNotDynamicFlag) {
  783. if (input_index >= input_abstract.size()) {
  784. MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract " << input_abstract.size();
  785. }
  786. (void)rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
  787. } else {
  788. if (item < 0) {
  789. MS_LOG(EXCEPTION) << "The dynamic input size check error the index should be -1 or positive number but got "
  790. << item;
  791. }
  792. AbstractBasePtrList dynamic_inputs_abs;
  793. for (auto index = item; index > 0; --index) {
  794. if (input_index >= input_abstract.size()) {
  795. MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract "
  796. << input_abstract.size();
  797. }
  798. (void)dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
  799. }
  800. (void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
  801. }
  802. }
  803. return rectifyed_abs_list;
  804. }
  805. AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
  806. auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract);
  807. return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list);
  808. }
  809. } // namespace
  810. AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  811. MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  812. MS_EXCEPTION_IF_NULL(primitive_vars);
  813. if (utils::isa<VectorRef>(sexp)) {
  814. return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  815. }
  816. if (utils::isa<VarPtr>(sexp)) {
  817. auto var_ptr = utils::cast<VarPtr>(sexp);
  818. MS_EXCEPTION_IF_NULL(var_ptr);
  819. if (var_ptr->primitive()) {
  820. (*primitive_vars)[var_ptr->primitive()] = var_ptr;
  821. return NewValueNode(var_ptr->primitive());
  822. }
  823. return CreateVarNodeWithSexp(sexp, graph);
  824. }
  825. if (utils::isa<AnfNodePtr>(sexp)) {
  826. return utils::cast<AnfNodePtr>(sexp);
  827. }
  828. auto value_node = CreateValueNodeWithSexp(sexp, primitive_vars);
  829. if (value_node == nullptr) {
  830. MS_LOG(EXCEPTION) << "Sexp cannot converted, sexp: " + sexp.ToString();
  831. }
  832. return value_node;
  833. }
  834. bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
  835. MS_EXCEPTION_IF_NULL(equiv1);
  836. MS_EXCEPTION_IF_NULL(equiv2);
  837. MS_EXCEPTION_IF_NULL(var_node);
  838. auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
  839. MS_EXCEPTION_IF_NULL(equiv1_node);
  840. auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
  841. MS_EXCEPTION_IF_NULL(equiv2_node);
  842. return *equiv1_node == *equiv2_node;
  843. }
  844. AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
  845. MS_EXCEPTION_IF_NULL(equiv);
  846. MS_EXCEPTION_IF_NULL(var_node);
  847. auto iter = (*equiv).find(var_node);
  848. if (iter == (*equiv).end()) {
  849. MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
  850. return nullptr;
  851. }
  852. auto res = utils::cast<AnfNodePtr>(iter->second);
  853. if (res == nullptr) {
  854. MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
  855. }
  856. return res;
  857. }
  858. bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
  859. MS_EXCEPTION_IF_NULL(n1);
  860. MS_EXCEPTION_IF_NULL(n2);
  861. auto n1_cnode = n1->cast<CNodePtr>();
  862. auto n2_cnode = n2->cast<CNodePtr>();
  863. MS_EXCEPTION_IF_NULL(n1_cnode);
  864. MS_EXCEPTION_IF_NULL(n2_cnode);
  865. auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  866. MS_EXCEPTION_IF_NULL(index_input1);
  867. auto value_node1 = index_input1->cast<ValueNodePtr>();
  868. MS_EXCEPTION_IF_NULL(value_node1);
  869. auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  870. MS_EXCEPTION_IF_NULL(index_input2);
  871. auto value_node2 = index_input2->cast<ValueNodePtr>();
  872. MS_EXCEPTION_IF_NULL(value_node2);
  873. return GetValue<int64_t>(value_node1->value()) < GetValue<int64_t>(value_node2->value());
  874. }
  875. bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
  876. MS_EXCEPTION_IF_NULL(node);
  877. if (!node->isa<CNode>()) {
  878. MS_LOG(INFO) << "node is not a cnode";
  879. return false;
  880. }
  881. auto cnode = node->cast<CNodePtr>();
  882. MS_EXCEPTION_IF_NULL(cnode);
  883. return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
  884. }
  885. bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
  886. MS_EXCEPTION_IF_NULL(node);
  887. TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
  888. if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
  889. return true;
  890. }
  891. MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
  892. return false;
  893. }
  894. ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
  895. MS_EXCEPTION_IF_NULL(value_node);
  896. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  897. MS_EXCEPTION_IF_NULL(new_value_node);
  898. new_value_node->set_abstract(value_node->abstract());
  899. // create kernel_info fo new value node
  900. auto kernel_info = std::make_shared<device::KernelInfo>();
  901. new_value_node->set_kernel_info(kernel_info);
  902. // create kernel_build_info for new value node
  903. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  904. MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
  905. // set the format of value_node to DEFAULT_FORMAT
  906. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  907. // set value node initial device data type = infer data type
  908. std::vector<TypeId> types;
  909. size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
  910. for (size_t index = 0; index < output_num; ++index) {
  911. types.push_back(kTypeUnknown);
  912. }
  913. kernel_build_info_builder->SetOutputsDeviceType(types);
  914. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  915. return new_value_node;
  916. }
  917. void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
  918. MS_EXCEPTION_IF_NULL(old_node);
  919. MS_EXCEPTION_IF_NULL(graph);
  920. auto manager = graph->manager();
  921. MS_EXCEPTION_IF_NULL(manager);
  922. // Find BatchNorm's output which is a Depend or UpdateState.
  923. auto node_users = manager->node_users()[old_node];
  924. for (const auto &node_index : node_users) {
  925. AnfNodePtr output = node_index.first;
  926. MS_EXCEPTION_IF_NULL(output);
  927. if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
  928. AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
  929. auto depend = output->cast<CNodePtr>();
  930. MS_EXCEPTION_IF_NULL(depend);
  931. manager->SetEdge(depend, node_index.second, new_node);
  932. }
  933. }
  934. }
  935. AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
  936. MS_EXCEPTION_IF_NULL(prim);
  937. auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
  938. auto ret = prim_eval_implement_map.find(prim);
  939. if (ret != prim_eval_implement_map.end()) {
  940. // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
  941. MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
  942. auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
  943. return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
  944. } else {
  945. // if the infer function has been not founded in the front infer map find it in the backend infer map instead
  946. auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
  947. auto ret_backend = prim_backend_eval_impl_map.find(prim);
  948. if (ret_backend != prim_backend_eval_impl_map.end()) {
  949. MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
  950. auto infer_spec_list = args_spec_list;
  951. if (!ret_backend->second.in_white_list_) {
  952. infer_spec_list = RectifyAbstract(prim, args_spec_list);
  953. }
  954. return ret_backend->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
  955. }
  956. }
  957. MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
  958. << " primitive type:" << prim->type_name();
  959. }
  960. kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
  961. std::vector<std::string> inputs_device_format;
  962. std::vector<std::string> outputs_device_format;
  963. std::vector<TypeId> inputs_device_type;
  964. std::vector<TypeId> outputs_device_type;
  965. std::vector<std::vector<size_t>> outputs_shape;
  966. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  967. for (size_t idx = 0; idx < node_list.size(); ++idx) {
  968. auto cnode = utils::cast<CNodePtr>(node_list[idx]);
  969. MS_EXCEPTION_IF_NULL(cnode);
  970. size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  971. for (size_t input_index = 0; input_index < input_num; ++input_index) {
  972. (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
  973. (void)inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
  974. }
  975. size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  976. for (size_t output_index = 0; output_index < output_num; ++output_index) {
  977. (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
  978. (void)outputs_device_type.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
  979. (void)outputs_shape.emplace_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
  980. }
  981. }
  982. builder.SetInputsFormat(inputs_device_format);
  983. builder.SetOutputsFormat(outputs_device_format);
  984. builder.SetInputsDeviceType(inputs_device_type);
  985. builder.SetOutputsDeviceType(outputs_device_type);
  986. return builder.Build();
  987. }
  988. std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
  989. MS_EXCEPTION_IF_NULL(node);
  990. auto manager = kernel_graph.manager();
  991. MS_EXCEPTION_IF_NULL(manager);
  992. auto output_num = AnfAlgo::GetOutputTensorNum(node);
  993. std::vector<int64_t> output_used_num(output_num, 0);
  994. if (output_num == 1) {
  995. output_used_num[0] = SizeToLong(manager->node_users()[node].size());
  996. } else {
  997. for (auto out_getitem : manager->node_users()[node]) {
  998. MS_EXCEPTION_IF_NULL(out_getitem.first);
  999. if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
  1000. continue;
  1001. }
  1002. auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
  1003. MS_EXCEPTION_IF_NULL(out_getitem_ptr);
  1004. auto getitem_input2 = out_getitem_ptr->input(kInputNodeOutputIndexInTupleGetItem);
  1005. auto output_idx = LongToSize(GetValue<int64_t>(GetValueNode(getitem_input2)));
  1006. output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
  1007. }
  1008. }
  1009. return output_used_num;
  1010. }
  1011. int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
  1012. auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
  1013. return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
  1014. }
  1015. void GetCustomOpAttrIndex(const PrimitivePtr &primitive, mindspore::HashSet<size_t> *indexes) {
  1016. if (primitive == nullptr || primitive->name() != prim::kPrimCustom->name()) {
  1017. return;
  1018. }
  1019. MS_EXCEPTION_IF_NULL(indexes);
  1020. auto input_names = primitive->GetAttr(kAttrInputNames);
  1021. auto attr_names = primitive->GetAttr(kAttrAttrNames);
  1022. if (input_names == nullptr || attr_names == nullptr) {
  1023. return;
  1024. }
  1025. auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  1026. auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
  1027. if (input_names_vec.size() >= attr_names_vec.size()) {
  1028. size_t offset = input_names_vec.size() - attr_names_vec.size();
  1029. for (size_t i = offset; i < input_names_vec.size(); ++i) {
  1030. if (input_names_vec[i] != attr_names_vec[i - offset]) {
  1031. MS_LOG(EXCEPTION) << primitive->name() << " found mismatching attr name " << input_names_vec[i]
  1032. << "in input_names and " << attr_names_vec[i - offset] << " in attr_names";
  1033. }
  1034. indexes->insert(i);
  1035. }
  1036. }
  1037. }
  1038. } // namespace opt
  1039. } // namespace mindspore