You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

specialize_transform.h 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_
#include <map>
#include <memory>
#include <set>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/manager.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
  31. namespace mindspore {
  32. namespace opt {
  33. namespace irpass {
  34. namespace internal {
  35. class SpecializeTransform {
  36. public:
  37. SpecializeTransform() : cache_() {}
  38. ~SpecializeTransform() = default;
  39. FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
  40. std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) {
  41. if (cache_.count(func_graph) == 0) {
  42. cache_[func_graph] = {};
  43. }
  44. auto &cache = cache_[func_graph];
  45. auto key = std::make_pair(graph_args, prim_args);
  46. if (cache.count(key) == 0) {
  47. auto mng = func_graph->manager();
  48. MS_EXCEPTION_IF_NULL(mng);
  49. FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared<TraceTransform>("sp"));
  50. mng->AddFuncGraph(new_fg);
  51. std::vector<AnfNodePtr> params = new_fg->parameters();
  52. std::vector<AnfNodePtr> new_params;
  53. size_t n = graph_args.size();
  54. for (size_t i = 0; i < n; i++) {
  55. if (graph_args[i] != nullptr) {
  56. auto arg = NewValueNode(graph_args[i]);
  57. (void)mng->Replace(params[i], arg);
  58. continue;
  59. }
  60. if (prim_args[i] != nullptr) {
  61. auto arg = NewValueNode(prim_args[i]);
  62. (void)mng->Replace(params[i], arg);
  63. continue;
  64. }
  65. if (value_args[i] != nullptr) {
  66. auto &const_tensor = *value_args[i];
  67. auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
  68. AnfNodePtr arg = NewValueNode(const_tensor_ptr);
  69. (void)mng->Replace(params[i], arg);
  70. continue;
  71. }
  72. new_params.push_back(params[i]);
  73. }
  74. mng->SetParameters(new_fg, new_params);
  75. cache[key] = new_fg;
  76. }
  77. return cache[key];
  78. }
  79. private:
  80. std::unordered_map<FuncGraphPtr,
  81. std::map<std::pair<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>>, FuncGraphPtr>>
  82. cache_;
  83. };
  84. } // namespace internal
  85. // {G, Xs}
  86. class SpecializeOnGraphArguments : public AnfVisitor {
  87. public:
  88. SpecializeOnGraphArguments() : specialize_transform_() {}
  89. ~SpecializeOnGraphArguments() override = default;
  90. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  91. if (!node->isa<CNode>() || node->func_graph() == nullptr) {
  92. return nullptr;
  93. }
  94. auto &inputs = node->cast<CNodePtr>()->inputs();
  95. if (!IsValueNode<FuncGraph>(inputs[0])) {
  96. return nullptr;
  97. }
  98. auto inp0_fg = GetValueNode<FuncGraphPtr>(inputs[0]);
  99. if (inp0_fg->recursive()) {
  100. return nullptr;
  101. }
  102. std::vector<FuncGraphPtr> graph_args;
  103. std::vector<PrimitivePtr> prim_args;
  104. std::vector<tensor::TensorPtr> value_node_args;
  105. std::vector<AnfNodePtr> new_xs;
  106. bool hasVNode = false;
  107. for (size_t i = 1; i < inputs.size(); i++) {
  108. if (IsValueNode<FuncGraph>(inputs[i])) {
  109. auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
  110. graph_args.push_back(fg_vnode);
  111. prim_args.emplace_back(nullptr);
  112. value_node_args.emplace_back(nullptr);
  113. hasVNode = true;
  114. } else if (IsValueNode<Primitive>(inputs[i])) {
  115. auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
  116. graph_args.emplace_back(nullptr);
  117. prim_args.push_back(p_vnode);
  118. value_node_args.emplace_back(nullptr);
  119. hasVNode = true;
  120. } else if (IsValueNode<tensor::Tensor>(inputs[i])) {
  121. tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
  122. graph_args.emplace_back(nullptr);
  123. prim_args.emplace_back(nullptr);
  124. value_node_args.emplace_back(t_vnode);
  125. hasVNode = true;
  126. } else {
  127. graph_args.emplace_back(nullptr);
  128. prim_args.emplace_back(nullptr);
  129. value_node_args.emplace_back(nullptr);
  130. new_xs.push_back(inputs[i]);
  131. }
  132. }
  133. if (!hasVNode) {
  134. return nullptr;
  135. }
  136. auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args);
  137. (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
  138. return node->func_graph()->NewCNode(new_xs);
  139. }
  140. private:
  141. internal::SpecializeTransform specialize_transform_;
  142. };
  143. // Eliminate unused parameters.
  144. // {G, Xs}
  145. class UnusedParasEliminater : public AnfVisitor {
  146. public:
  147. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  148. if (!node->isa<CNode>() || node->func_graph() == nullptr) {
  149. return nullptr;
  150. }
  151. auto cnode = node->cast<CNodePtr>();
  152. MS_EXCEPTION_IF_NULL(cnode);
  153. auto &inputs = cnode->inputs();
  154. auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
  155. MS_EXCEPTION_IF_NULL(fg);
  156. std::vector<AnfNodePtr> parameters = fg->parameters();
  157. size_t size = parameters.size();
  158. if (size != inputs.size() - 1) {
  159. return nullptr;
  160. }
  161. std::vector<AnfNodePtr> new_xs;
  162. std::vector<bool> keep_parameters;
  163. auto mng = fg->manager();
  164. MS_EXCEPTION_IF_NULL(mng);
  165. auto &node_users = mng->node_users();
  166. bool has_unused_para = false;
  167. for (size_t i = 0; i < size; ++i) {
  168. auto iter = node_users.find(parameters[i]);
  169. if (iter != node_users.end() && !iter->second.empty()) {
  170. keep_parameters.push_back(true);
  171. new_xs.push_back(inputs[i + 1]);
  172. continue;
  173. }
  174. keep_parameters.push_back(false);
  175. has_unused_para = true;
  176. }
  177. if (!has_unused_para) {
  178. return nullptr;
  179. }
  180. FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
  181. mng->AddFuncGraph(new_fg);
  182. std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
  183. std::vector<AnfNodePtr> new_parameters;
  184. for (size_t i = 0; i < size; i++) {
  185. if (keep_parameters[i]) {
  186. if (parameters[i]->abstract() != nullptr) {
  187. new_fg_parameters[i]->set_abstract(parameters[i]->abstract());
  188. }
  189. new_parameters.push_back(new_fg_parameters[i]);
  190. }
  191. }
  192. mng->SetParameters(new_fg, new_parameters);
  193. (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
  194. return node->func_graph()->NewCNode(new_xs);
  195. }
  196. };
  197. // Eliminate unused outputs.
  198. // {G, Xs}
  199. class UnusedOutputEliminater : public AnfVisitor {
  200. public:
  201. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  202. if (!node->isa<CNode>() || node->func_graph() == nullptr) {
  203. return nullptr;
  204. }
  205. auto &inputs = node->cast<CNodePtr>()->inputs();
  206. auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
  207. MS_EXCEPTION_IF_NULL(fg);
  208. auto mng = fg->manager();
  209. MS_EXCEPTION_IF_NULL(mng);
  210. if (fg->recursive()) {
  211. return nullptr;
  212. }
  213. auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
  214. mng->AddFuncGraph(new_fg);
  215. auto new_fg_output = new_fg->output();
  216. if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
  217. return nullptr;
  218. }
  219. auto output_cnode = new_fg_output->cast<CNodePtr>();
  220. auto &node_users = mng->node_users();
  221. if (node_users.count(node) == 0 || node_users[node].empty()) {
  222. return nullptr;
  223. }
  224. std::unordered_set<int> used_output_idx;
  225. std::vector<std::pair<AnfNodePtr, int>> all_users;
  226. for (auto &node_user : node_users[node]) {
  227. if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
  228. return nullptr;
  229. }
  230. auto user_cnode = node_user.first->cast<CNodePtr>();
  231. size_t used_idx = GetValue<int>(user_cnode->input(2)->cast<ValueNodePtr>()->value());
  232. used_output_idx.insert(used_idx);
  233. all_users.push_back(std::make_pair(node_user.first, used_idx));
  234. }
  235. if (used_output_idx.size() >= output_cnode->inputs().size() - 1) {
  236. // all output has users.
  237. return nullptr;
  238. }
  239. if (used_output_idx.empty()) {
  240. // we do not process this case.
  241. return nullptr;
  242. } else if (used_output_idx.size() == 1) {
  243. // after eliminate, only one output left.
  244. new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1));
  245. // update users.
  246. for (auto &ret_user : all_users) {
  247. (void)mng->Replace(ret_user.first, node);
  248. }
  249. } else {
  250. // after eliminate, create new multi output.
  251. std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)};
  252. std::unordered_map<int, int> new_idx_map;
  253. for (auto idx : used_output_idx) {
  254. new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1);
  255. new_output_inputs.push_back(output_cnode->input(idx + 1));
  256. }
  257. new_fg->set_output(new_fg->NewCNode(new_output_inputs));
  258. // update users.
  259. for (auto &ret_user : all_users) {
  260. auto ret_user_cnode = ret_user.first->cast<CNodePtr>();
  261. ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second]));
  262. }
  263. }
  264. auto new_sx = inputs;
  265. new_sx[0] = NewValueNode(new_fg);
  266. return node->func_graph()->NewCNode(new_sx);
  267. }
  268. };
  269. } // namespace irpass
  270. } // namespace opt
  271. } // namespace mindspore
  272. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_