You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

helper.cc 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pre_activate/common/helper.h"
  17. #include <string>
  18. #include <unordered_set>
  19. #include <algorithm>
  20. #include "utils/utils.h"
  21. #include "utils/base_ref.h"
  22. #include "session/anf_runtime_algorithm.h"
  23. #include "operator/ops.h"
  24. #include "common/utils.h"
  25. #include "device/kernel_info.h"
  26. #include "utils/context/ms_context.h"
  27. namespace mindspore {
  28. namespace opt {
  29. constexpr size_t kType32Len = 4;
  30. std::vector<int> Convert2Int(const std::vector<size_t> &v) {
  31. std::vector<int> result;
  32. (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
  33. return result;
  34. }
  35. bool UnVisited(const BaseRef &n) {
  36. if (utils::isa<AnfNodePtr>(n)) {
  37. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  38. MS_EXCEPTION_IF_NULL(in);
  39. if (IsValueNode<Primitive>(in)) {
  40. auto value_node = in->cast<ValueNodePtr>();
  41. MS_EXCEPTION_IF_NULL(value_node);
  42. auto value = value_node->value();
  43. MS_EXCEPTION_IF_NULL(value);
  44. auto prim_py = value->cast<PrimitivePtr>();
  45. MS_EXCEPTION_IF_NULL(prim_py);
  46. return !prim_py->HasAttr(kAttrVisited);
  47. } else {
  48. return false;
  49. }
  50. }
  51. return false;
  52. }
  53. bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode) {
  54. MS_EXCEPTION_IF_NULL(node);
  55. if (!node->isa<CNode>()) {
  56. MS_LOG(ERROR) << "The node is expected to be a cnode";
  57. return false;
  58. }
  59. *cnode = node->cast<CNodePtr>();
  60. if (*cnode == nullptr) {
  61. return false;
  62. }
  63. if ((*cnode)->inputs().size() < IntToSize(input_size)) {
  64. auto op_name = AnfAlgo::GetCNodeName(*cnode);
  65. MS_LOG(ERROR) << "op[" + op_name + "] has less than " << input_size << " inputs.";
  66. return false;
  67. }
  68. return true;
  69. }
  70. CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size) {
  71. MS_EXCEPTION_IF_NULL(node);
  72. if (!node->isa<CNode>()) {
  73. MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
  74. }
  75. auto cnode = node->cast<CNodePtr>();
  76. MS_EXCEPTION_IF_NULL(cnode);
  77. if (cnode->inputs().size() != IntToSize(input_size)) {
  78. auto op_name = AnfAlgo::GetCNodeName(cnode);
  79. MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs.";
  80. }
  81. return cnode;
  82. }
  83. void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) {
  84. MS_EXCEPTION_IF_NULL(cnode);
  85. if (cnode->inputs().size() != input_size) {
  86. MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size;
  87. }
  88. }
  89. bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
  90. MS_EXCEPTION_IF_NULL(node_x);
  91. MS_EXCEPTION_IF_NULL(node_y);
  92. return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
  93. AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
  94. }
// Collapses a (transop -> depend -> transop) chain: builds a replacement
// Depend node that forwards the value already produced by the inner transop
// (keeping the original control-dependency input), so the outer transop/depend
// pair becomes redundant for the caller's pattern replacement.
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // `node` must be a transop cnode with exactly kTransOpInputNum inputs.
  auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum);
  // NOTE(review): the input index uses kCastInputNum although the node was
  // checked against kTransOpInputNum — presumably the two constants are equal;
  // confirm against their definitions.
  auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum);
  auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum);
  MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1));
  MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1));
  // The data value that was already transformed by the inner transop.
  auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1);
  MS_EXCEPTION_IF_NULL(transed_node);
  // New Depend: forward the transformed value, keep the original control edge.
  std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
                                                depend_cnode->input(kDependInputNum - 1)};
  AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
  MS_EXCEPTION_IF_NULL(replace_depend);
  // The replacement carries the abstract of the value it forwards.
  auto transed_abstract = transed_node->abstract();
  replace_depend->set_abstract(transed_abstract);
  return replace_depend;
}
  112. bool Visited(const BaseRef &n) {
  113. if (utils::isa<AnfNodePtr>(n)) {
  114. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  115. MS_EXCEPTION_IF_NULL(in);
  116. if (IsValueNode<Primitive>(in)) {
  117. auto value_node = in->cast<ValueNodePtr>();
  118. MS_EXCEPTION_IF_NULL(value_node);
  119. auto value = value_node->value();
  120. MS_EXCEPTION_IF_NULL(value);
  121. auto prim_py = value->cast<PrimitivePtr>();
  122. MS_EXCEPTION_IF_NULL(prim_py);
  123. return prim_py->HasAttr(kAttrVisited);
  124. } else {
  125. return false;
  126. }
  127. }
  128. return false;
  129. }
// Builds a fused ConvBN1 cnode from a Conv cnode followed by a BatchNorm
// cnode, then expands it into kConvBn1OutputNum tuple-getitem nodes stored in
// *conv_bn1_outputs.
void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode,
                            std::vector<AnfNodePtr> *conv_bn1_outputs) {
  auto prim = std::make_shared<Primitive>(kConvBN1OpName);
  std::vector<AnfNodePtr> conv_bn1_inputs = {NewValueNode(prim)};
  MS_EXCEPTION_IF_NULL(conv_cnode);
  // All the inputs of conv_bn1 are from the inputs of conv
  for (size_t i = 1; i < conv_cnode->inputs().size(); i++) {
    conv_bn1_inputs.push_back(conv_cnode->input(i));
  }
  MS_EXCEPTION_IF_NULL(func_graph);
  CNodePtr conv_bn1_cnode = func_graph->NewCNode(conv_bn1_inputs);
  MS_EXCEPTION_IF_NULL(conv_bn1_cnode);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  conv_bn1_cnode->set_kernel_info(kernel_info);
  // Set attr for conv_bn1
  AnfAlgo::CopyNodeAttrs(conv_cnode, conv_bn1_cnode);
  // Set abstract of conv_bn1
  MS_EXCEPTION_IF_NULL(bn_cnode);
  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_cnode->abstract());
  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
  // Output abstracts: (1) the conv output, (2) a float32 tensor shaped like
  // the BN's previous-node output at index kVariance - 1 (presumably the
  // variance slot — confirm against the BN op definition), (3) the BN
  // save-mean tuple element.
  AbstractBasePtrList conv_bn1_abstract_list;
  conv_bn1_abstract_list.push_back(conv_cnode->abstract());
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(
    kFloat32, Convert2Int(AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, kVariance - 1)));
  conv_bn1_abstract_list.push_back(abstract_tensor);
  conv_bn1_abstract_list.push_back(bn_abstract_tuple->elements()[kSaveMean]);
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(conv_bn1_abstract_list);
  conv_bn1_cnode->set_abstract(abstract_tuple);
  // Expose each fused output through its own tuple-getitem node.
  CreateMultipleOutputsOfAnfNode(func_graph, conv_bn1_cnode, kConvBn1OutputNum, conv_bn1_outputs);
}
// Builds a FusedBN2 cnode from the first two FusedBN1 outputs plus two of the
// original BN node's inputs, and expands it into kBN2OutputNum tuple-getitem
// nodes stored in *fused_bn2_outputs.
void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &fused_bn1_outputs,
                             const CNodePtr &bn_node, std::vector<AnfNodePtr> *fused_bn2_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(bn_node);
  MS_EXCEPTION_IF_NULL(fused_bn2_outputs);
  if (bn_node->inputs().size() != kBnInputNum) {
    MS_LOG(EXCEPTION) << "BN node has wrong input size";
  }
  if (fused_bn1_outputs.size() != kBN1OutputNum) {
    MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size";
  }
  // the inputs of fused_bn2 are from the outputs of fused_bn1 and the inputs of bn
  std::vector<AnfNodePtr> fused_bn2_inputs = {NewValueNode(std::make_shared<Primitive>(kFusedBN2OpName))};
  fused_bn2_inputs.push_back(fused_bn1_outputs[0]);
  fused_bn2_inputs.push_back(fused_bn1_outputs[1]);
  // BN inputs 4 and 5 — presumably the moving mean/variance slots; confirm
  // against the BatchNorm op's input layout.
  fused_bn2_inputs.push_back(bn_node->input(4));
  fused_bn2_inputs.push_back(bn_node->input(5));
  auto fused_bn2 = graph->NewCNode(fused_bn2_inputs);
  MS_EXCEPTION_IF_NULL(fused_bn2);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  fused_bn2->set_kernel_info(kernel_info);
  // The three outputs mirror BN outputs 4, 1 and 2 (in that order).
  auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 4), AnfAlgo::GetOutputInferDataType(bn_node, 1),
                AnfAlgo::GetOutputInferDataType(bn_node, 2)};
  auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 4), AnfAlgo::GetOutputInferShape(bn_node, 1),
                 AnfAlgo::GetOutputInferShape(bn_node, 2)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn2.get());
  fused_bn2->set_scope(bn_node->scope());
  AnfAlgo::CopyNodeAttr(kAttrMomentum, bn_node, fused_bn2);
  CreateMultipleOutputsOfAnfNode(graph, fused_bn2, kBN2OutputNum, fused_bn2_outputs);
}
// Builds a FusedBN3 cnode from the conv data input, the first FusedBN1 and
// FusedBN2 outputs, and two of the original BN node's inputs. FusedBN3 has a
// single output, so it is appended directly to *fused_bn3_outputs rather than
// expanded via tuple-getitem nodes.
void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input,
                             const std::vector<AnfNodePtr> &fused_bn1_outputs,
                             const std::vector<AnfNodePtr> &fused_bn2_outputs, const CNodePtr &bn_node,
                             std::vector<AnfNodePtr> *fused_bn3_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(data_input);
  MS_EXCEPTION_IF_NULL(bn_node);
  MS_EXCEPTION_IF_NULL(fused_bn3_outputs);
  if (bn_node->inputs().size() != kBnInputNum) {
    MS_LOG(EXCEPTION) << "BN node has wrong input size";
  }
  if (fused_bn1_outputs.size() != kBN1OutputNum) {
    MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size";
  }
  if (fused_bn2_outputs.size() != kBN2OutputNum) {
    MS_LOG(EXCEPTION) << "BN2 outputs has wrong input size";
  }
  // the inputs of fused_bn3 are from the outputs of fused_bn1 and the inputs of bn
  std::vector<AnfNodePtr> fused_bn3_inputs = {NewValueNode(std::make_shared<Primitive>(kFusedBN3OpName))};
  fused_bn3_inputs.push_back(data_input);
  fused_bn3_inputs.push_back(fused_bn1_outputs[0]);
  fused_bn3_inputs.push_back(fused_bn2_outputs[0]);
  // BN inputs 2 and 3 — presumably the scale/offset slots; confirm against
  // the BatchNorm op's input layout.
  fused_bn3_inputs.push_back(bn_node->input(2));
  fused_bn3_inputs.push_back(bn_node->input(3));
  auto fused_bn3 = graph->NewCNode(fused_bn3_inputs);
  MS_EXCEPTION_IF_NULL(fused_bn3);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  fused_bn3->set_kernel_info(kernel_info);
  // The single output mirrors BN output 0's inferred type and shape.
  auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn3.get());
  fused_bn3->set_scope(bn_node->scope());
  // The BN attribute kAttrEpsilon is exposed as kAttrEps on FusedBN3.
  AnfAlgo::CopyNodeAttr(kAttrEpsilon, kAttrEps, bn_node, fused_bn3);
  (*fused_bn3_outputs).push_back(fused_bn3);
}
  225. void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
  226. std::vector<AnfNodePtr> *outputs) {
  227. MS_EXCEPTION_IF_NULL(func_graph);
  228. MS_EXCEPTION_IF_NULL(node);
  229. MS_EXCEPTION_IF_NULL(outputs);
  230. for (size_t i = 0; i < output_num; i++) {
  231. auto idx = NewValueNode(SizeToInt(i));
  232. MS_EXCEPTION_IF_NULL(idx);
  233. int temp = SizeToInt(i);
  234. auto imm = std::make_shared<Int32Imm>(temp);
  235. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  236. idx->set_abstract(abstract_scalar);
  237. auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  238. MS_EXCEPTION_IF_NULL(tuple_getitem);
  239. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)},
  240. {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get());
  241. (*outputs).push_back(tuple_getitem);
  242. }
  243. }
  244. template <typename T>
  245. tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
  246. size_t data_length) {
  247. MS_EXCEPTION_IF_NULL(value_tuple_ptr);
  248. MS_EXCEPTION_IF_NULL(type_ptr);
  249. std::vector<T> values;
  250. for (const auto &v : value_tuple_ptr->value()) {
  251. MS_EXCEPTION_IF_NULL(v);
  252. if (v->isa<Scalar>()) {
  253. ScalarPtr scalar = v->cast<ScalarPtr>();
  254. values.push_back(GetValue<T>(scalar));
  255. } else {
  256. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  257. return nullptr;
  258. }
  259. }
  260. std::vector<int> tensor_shape = {SizeToInt(values.size())};
  261. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
  262. MS_EXCEPTION_IF_NULL(tensor);
  263. tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
  264. tensor->set_device_info(device_info);
  265. auto data_ptr = tensor->data_c(true);
  266. MS_EXCEPTION_IF_NULL(data_ptr);
  267. auto elem_num = values.size() * data_length;
  268. auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
  269. if (ret_code != 0) {
  270. MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
  271. }
  272. return tensor;
  273. }
  274. tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  275. MS_EXCEPTION_IF_NULL(value_tuple);
  276. tensor::TensorPtr tensor = nullptr;
  277. ValuePtr v = *(value_tuple->value().begin());
  278. MS_EXCEPTION_IF_NULL(v);
  279. // Currently we only deal with the scalar tuple
  280. if (!v->isa<Scalar>()) {
  281. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  282. return nullptr;
  283. }
  284. ScalarPtr scalar = v->cast<ScalarPtr>();
  285. MS_EXCEPTION_IF_NULL(scalar);
  286. if (scalar->isa<IntergerImm>()) {
  287. tensor = CreateTensorWithValueTuple<int>(value_tuple, kInt32, kType32Len);
  288. } else if (scalar->isa<FloatImm>()) {
  289. tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, kType32Len);
  290. } else {
  291. auto type = scalar->type();
  292. auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
  293. MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
  294. return nullptr;
  295. }
  296. return tensor;
  297. }
  298. bool IsNopNode(const AnfNodePtr &node) {
  299. auto context_ptr = MsContext::GetInstance();
  300. MS_EXCEPTION_IF_NULL(context_ptr);
  301. if (context_ptr->device_target() != kAscendDevice) {
  302. return false;
  303. }
  304. static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
  305. prim::kPrimSqueeze->name(), prim::kPrimFlatten->name()};
  306. if (node == nullptr || !node->isa<CNode>()) {
  307. return false;
  308. }
  309. CNodePtr cnode = node->cast<CNodePtr>();
  310. MS_EXCEPTION_IF_NULL(cnode);
  311. if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end()) {
  312. return false;
  313. }
  314. return true;
  315. }
  316. void HideNopNode(session::KernelGraph *const graph) {
  317. MS_EXCEPTION_IF_NULL(graph);
  318. auto execution_order = graph->execution_order();
  319. MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
  320. std::vector<CNodePtr> new_nodes;
  321. for (auto &cnode : execution_order) {
  322. MS_EXCEPTION_IF_NULL(cnode);
  323. if (!IsNopNode(cnode)) {
  324. new_nodes.push_back(cnode);
  325. }
  326. }
  327. graph->set_execution_order(new_nodes);
  328. MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
  329. }
// Removes nop nodes from the execution order and reroutes each consumer's
// input past the nop node to its single data input. Repeats until a full pass
// makes no change, because rerouting one level can expose chained nop nodes.
void RemoveNopNode(session::KernelGraph *const graph) {
  MS_EXCEPTION_IF_NULL(graph);
  bool changed = true;
  while (changed) {
    changed = false;
    std::vector<CNodePtr> new_nodes;
    for (auto &cnode : graph->execution_order()) {
      MS_EXCEPTION_IF_NULL(cnode);
      // ignore nop node itself
      if (IsNopNode(cnode)) {
        continue;
      }
      // Replace the input which is nop node
      std::vector<AnfNodePtr> new_inputs;
      new_inputs.push_back(cnode->input(0));  // input 0 is the primitive; always kept
      bool need_update = false;
      for (size_t i = 1; i < cnode->inputs().size(); ++i) {
        auto input = cnode->input(i);
        MS_EXCEPTION_IF_NULL(input);
        auto cinput = input->cast<CNodePtr>();
        if (cinput == nullptr || !IsNopNode(cinput)) {
          new_inputs.push_back(input);
          continue;
        }
        // Only bypass a nop node of exactly two inputs (primitive + data);
        // other shapes are kept as-is.
        if (cinput->inputs().size() == 2) {
          new_inputs.push_back(cinput->input(1));
          need_update = true;
          changed = true;
        } else {
          new_inputs.push_back(input);
        }
      }
      if (need_update) {
        cnode->set_inputs(new_inputs);
      }
      // push into new execution list
      new_nodes.push_back(cnode);
    }
    graph->set_execution_order(new_nodes);
  }
}
  371. bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  372. MS_EXCEPTION_IF_NULL(graph);
  373. MS_EXCEPTION_IF_NULL(node);
  374. auto manager = graph->manager();
  375. MS_EXCEPTION_IF_NULL(manager);
  376. if (manager->node_users().find(node) == manager->node_users().end()) {
  377. MS_LOG(EXCEPTION) << "node has no output in manager";
  378. }
  379. return manager->node_users()[node].size() > 1;
  380. }
  381. AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
  382. auto idx = NewValueNode(SizeToInt(output_idx));
  383. MS_EXCEPTION_IF_NULL(idx);
  384. auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx));
  385. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  386. idx->set_abstract(abstract_scalar);
  387. AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  388. MS_EXCEPTION_IF_NULL(tuple_getitem);
  389. tuple_getitem->set_scope(node->scope());
  390. std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
  391. TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
  392. AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  393. return tuple_getitem;
  394. }
// Moves constant (ValueNode) inputs of `cnode` — at the positions listed in
// `input_attrs` — into primitive attributes keyed by the op's input names.
// Remaining inputs are kept, and the kAttrInputNames attribute is rewritten
// to match the surviving inputs. No-op if the primitive has no input names.
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> new_inputs;
  std::vector<std::string> new_input_names;
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    // Without input names there is no attribute key to write; skip quietly.
    MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
    return;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto inputs = cnode->inputs();
  new_inputs.push_back(inputs[0]);  // input 0 is the primitive itself; always kept
  bool need_update = false;
  // i indexes data inputs (0-based); the actual node is inputs[i + 1].
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
      auto value_node = input_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
      if (i >= input_names_vec.size()) {
        MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
      }
      // Promote the constant input to a primitive attribute under its name.
      primitive->set_attr(input_names_vec[i], value_node->value());
      need_update = true;
    } else {
      new_inputs.push_back(input_node);
      if (i < input_names_vec.size()) {
        new_input_names.push_back(input_names_vec[i]);
      }
    }
  }
  if (need_update) {
    // Update cnode's inputs
    cnode->set_inputs(new_inputs);
    // Update cnode's input_names attr
    primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
  }
}
  436. } // namespace opt
  437. } // namespace mindspore