You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

visit.cc 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "backend/optimizer/common/visit.h"
  19. #include <vector>
  20. #include <memory>
  21. #include <algorithm>
  22. #include "backend/optimizer/common/pattern_engine.h"
  23. #include "utils/any.h"
  24. #include "ir/anf.h"
  25. #include "ir/func_graph.h"
  26. #include "utils/log_adapter.h"
  27. /* namespace to support utils definition */
  28. namespace mindspore {
  29. bool CheckIfNeedExpand(const std::vector<BaseRef> &list) {
  30. return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); });
  31. }
  32. std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
  33. std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>();
  34. for (auto &item : list) {
  35. if (utils::isa<Seq>(item)) {
  36. const Seq &seq = utils::cast<Seq>(item);
  37. new_list->insert(new_list->end(), seq.begin(), seq.end());
  38. } else {
  39. new_list->push_back(item);
  40. }
  41. }
  42. return new_list;
  43. }
  44. static BaseRef GetVar(const BaseRef &x) {
  45. if (utils::isa<AnfNodePtr>(x)) {
  46. auto node = utils::cast<AnfNodePtr>(x);
  47. MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
  48. if (node->isa<VarNode>()) {
  49. MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
  50. return node->cast<VarNodePtr>()->var_;
  51. }
  52. }
  53. return x;
  54. }
  55. bool Visitor::Visit(const VectorRef &v_any, VectorRef *const values_ref, BaseRef *const visit_out) const {
  56. std::vector<BaseRef> out;
  57. for (const auto &element : v_any) {
  58. out.push_back(element);
  59. values_ref->push_back(GetVar(element));
  60. }
  61. if (visit_out != nullptr) {
  62. *visit_out = ExpandList(out);
  63. }
  64. return true;
  65. }
  66. bool Visitor::Visit(const BaseRef &any, VectorRef *const values_ref, BaseRef *const visit_out) const {
  67. if (utils::isa<Seq>(any)) {
  68. return Visit(utils::cast<Seq>(any), values_ref, visit_out);
  69. } else if (utils::isa<AnfNodePtr>(any)) {
  70. auto nodeptr = utils::cast<AnfNodePtr>(any);
  71. AnfNodePtr output;
  72. AnfNodePtr *p_output = &output;
  73. if (visit_out == nullptr) {
  74. p_output = nullptr;
  75. }
  76. Visit(nodeptr, values_ref, p_output);
  77. if (visit_out != nullptr) {
  78. *visit_out = output;
  79. }
  80. return true;
  81. }
  82. MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString();
  83. return false;
  84. }
  85. void Visitor::Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const {
  86. if (node->isa<CNode>()) {
  87. Visit(node->cast<CNodePtr>(), values_ref, output);
  88. return;
  89. }
  90. if (node->isa<ValueNode>()) {
  91. Visit(node->cast<ValueNodePtr>(), values_ref, output);
  92. return;
  93. }
  94. if (output != nullptr) {
  95. *output = node;
  96. }
  97. }
  98. void Visitor::Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const {
  99. // if output is nullptr, it's not required to make the new CNode node.
  100. if (output == nullptr) {
  101. for (auto &inp : cnode->inputs()) {
  102. auto var = GetVar(inp);
  103. values_ref->push_back(var);
  104. }
  105. if (cnode->func_graph() != nullptr) {
  106. values_ref->push_back(GetVar(cnode->func_graph()));
  107. } else {
  108. values_ref->push_back(GetVar(cnode->func_graph_as_var()));
  109. }
  110. return;
  111. }
  112. std::vector<AnfNodePtr> new_inputs;
  113. std::vector<BaseRef> after_cnode_fn;
  114. std::shared_ptr<VectorRef> out;
  115. for (auto &input : cnode->inputs()) {
  116. after_cnode_fn.push_back(input);
  117. values_ref->push_back(GetVar(input));
  118. }
  119. if (CheckIfNeedExpand(after_cnode_fn)) {
  120. out = ExpandList(after_cnode_fn);
  121. }
  122. std::vector<BaseRef> &outs = after_cnode_fn;
  123. if (out != nullptr) {
  124. outs = out->elements();
  125. }
  126. for (auto &any_item : outs) {
  127. if (!utils::isa<AnfNodePtr>(any_item)) {
  128. MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr";
  129. }
  130. new_inputs.push_back(utils::cast<AnfNodePtr>(any_item));
  131. }
  132. BaseRef any_fg;
  133. AnfNodePtr new_cnode = nullptr;
  134. if (cnode->func_graph() != nullptr) {
  135. any_fg = cnode->func_graph();
  136. values_ref->push_back(GetVar(any_fg));
  137. if (!utils::isa<FuncGraphPtr>(any_fg)) {
  138. MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
  139. }
  140. new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
  141. } else {
  142. any_fg = cnode->func_graph_as_var();
  143. values_ref->push_back(GetVar(any_fg));
  144. if (utils::isa<VarPtr>(any_fg)) {
  145. new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
  146. } else if (utils::isa<FuncGraphPtr>(any_fg)) {
  147. new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
  148. } else {
  149. MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr";
  150. }
  151. }
  152. new_cnode->set_abstract(cnode->abstract());
  153. *output = new_cnode;
  154. }
  155. void Visitor::Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const {
  156. values_ref->push_back(GetVar(vnode->value()));
  157. const BaseRef &value = utils::cast<ValuePtr>(vnode->value());
  158. if (utils::isa<ValuePtr>(value)) {
  159. if (output != nullptr) {
  160. auto ct = NewValueNode(utils::cast<ValuePtr>(value));
  161. ct->set_abstract(vnode->abstract());
  162. *output = ct;
  163. }
  164. return;
  165. }
  166. MS_LOG(EXCEPTION) << "Visit result is not ValuePtr.";
  167. }
  168. } // namespace mindspore