You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

map.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "operator/composite/map.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <utility>
  20. #include <vector>
  21. #include "ir/anf.h"
  22. #include "ir/func_graph.h"
  23. #include "pipeline/static_analysis/abstract_value.h"
  24. #include "pipeline/static_analysis/abstract_function.h"
  25. #include "pipeline/static_analysis/dshape.h"
  26. #include "pybind_api/api_register.h"
  27. #include "debug/trace.h"
  28. #include "operator/ops.h"
  29. #include "./common.h"
  30. namespace mindspore {
  31. // namespace to support composite operators definition
  32. namespace prim {
  33. using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
  34. AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
  35. MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
  36. MS_EXCEPTION_IF_NULL(func_graph);
  37. std::vector<AnfNodePtr> inputs;
  38. if (fn_arg != nullptr) {
  39. inputs.emplace_back(fn_arg);
  40. } else {
  41. inputs.emplace_back(NewValueNode(fn_leaf_));
  42. }
  43. inputs.insert(inputs.end(), args.begin(), args.end());
  44. return func_graph->NewCNode(inputs);
  45. }
  46. FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
  47. // Generate func for leaf nodes
  48. FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  49. ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  50. ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  51. ptrGraph->debug_info()->set_name("map");
  52. AnfNodePtr ptrFnArg = nullptr;
  53. if (fn_leaf_ == nullptr) {
  54. ptrFnArg = ptrGraph->add_parameter();
  55. }
  56. AnfNodePtrList args;
  57. for (size_t i = 0; i < args_size; ++i) {
  58. args.emplace_back(ptrGraph->add_parameter());
  59. }
  60. ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args));
  61. return ptrGraph;
  62. }
  63. AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
  64. const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  65. MS_EXCEPTION_IF_NULL(func_graph);
  66. MS_EXCEPTION_IF_NULL(type);
  67. std::size_t size = type->elements().size();
  68. bool is_not_same =
  69. std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  70. auto lhs = std::dynamic_pointer_cast<List>(item.second);
  71. MS_EXCEPTION_IF_NULL(lhs);
  72. return lhs->elements().size() != size;
  73. });
  74. if (is_not_same) {
  75. MS_LOG(EXCEPTION) << "List in Map should have same length";
  76. }
  77. std::vector<AnfNodePtr> inputs;
  78. inputs.push_back(NewValueNode(prim::kPrimMakeList));
  79. for (int i = 0; i < SizeToInt(size); ++i) {
  80. MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target";
  81. auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
  82. auto fn = NewValueNode(ptrGraph);
  83. std::vector<AnfNodePtr> inputs2;
  84. inputs2.push_back(fn);
  85. if (fn_arg != nullptr) {
  86. inputs2.push_back(fn_arg);
  87. }
  88. (void)std::transform(
  89. arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
  90. [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
  91. return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
  92. });
  93. inputs.push_back(func_graph->NewCNode(inputs2));
  94. }
  95. return func_graph->NewCNode(inputs);
  96. }
  97. AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
  98. const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  99. MS_EXCEPTION_IF_NULL(func_graph);
  100. MS_EXCEPTION_IF_NULL(type);
  101. std::size_t size = type->elements().size();
  102. bool is_not_same =
  103. std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  104. auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
  105. MS_EXCEPTION_IF_NULL(lhs);
  106. return lhs->elements().size() != size;
  107. });
  108. if (is_not_same) {
  109. MS_LOG(EXCEPTION) << "tuple in Map should have same length";
  110. }
  111. std::vector<AnfNodePtr> inputs;
  112. inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  113. for (int i = 0; i < SizeToInt(size); ++i) {
  114. MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs";
  115. auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
  116. auto fn = NewValueNode(ptrGraph);
  117. std::vector<AnfNodePtr> inputs2;
  118. inputs2.push_back(fn);
  119. if (fn_arg != nullptr) {
  120. inputs2.push_back(fn_arg);
  121. }
  122. (void)std::transform(
  123. arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
  124. [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
  125. return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
  126. });
  127. inputs.push_back(func_graph->NewCNode(inputs2));
  128. }
  129. return func_graph->NewCNode(inputs);
  130. }
  131. AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
  132. const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  133. MS_EXCEPTION_IF_NULL(type);
  134. MS_EXCEPTION_IF_NULL(func_graph);
  135. std::vector<AnfNodePtr> inputs;
  136. inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  137. inputs.push_back(NewValueNode(type));
  138. std::size_t attrSize = type->GetAttributes().size();
  139. for (std::size_t i = 0; i < attrSize; ++i) {
  140. MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs";
  141. auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
  142. auto fn = NewValueNode(ptrGraph);
  143. std::vector<AnfNodePtr> inputs2;
  144. inputs2.push_back(fn);
  145. if (fn_arg != nullptr) {
  146. inputs2.push_back(fn_arg);
  147. }
  148. int j = 0;
  149. for (auto item : arg_pairs) {
  150. inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
  151. j++;
  152. }
  153. inputs.push_back(func_graph->NewCNode(inputs2));
  154. }
  155. return func_graph->NewCNode(inputs);
  156. }
  157. AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  158. bool found = false;
  159. TypeId id = kObjectTypeEnd;
  160. std::pair<AnfNodePtr, TypePtr> pair;
  161. for (auto &item : arg_pairs) {
  162. pair = item;
  163. MS_LOG(DEBUG) << "Map " << pair.second->ToString();
  164. id = item.second->type_id();
  165. if (nonleaf_.count(id)) {
  166. found = true;
  167. break;
  168. }
  169. }
  170. if (found) {
  171. // In a nonleaf situation, all arguments must have the same generic.
  172. bool is_not_same =
  173. std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
  174. if (item.first != pair.first) {
  175. return item.second->type_id() != pair.second->type_id();
  176. }
  177. return false;
  178. });
  179. if (is_not_same) {
  180. std::ostringstream oss;
  181. oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
  182. << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
  183. int idx = 0;
  184. for (auto &item : arg_pairs) {
  185. oss << ++idx << ": " << item.second->ToString() << "\n";
  186. }
  187. MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n"
  188. << oss.str() << pair.second->ToString() << "\n";
  189. }
  190. }
  191. switch (id) {
  192. case kObjectTypeList: {
  193. auto type = std::static_pointer_cast<List>(pair.second);
  194. return FullMakeList(type, func_graph, fn_arg, arg_pairs);
  195. }
  196. case kObjectTypeTuple: {
  197. auto type = std::static_pointer_cast<Tuple>(pair.second);
  198. return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
  199. }
  200. case kObjectTypeClass: {
  201. auto type = std::static_pointer_cast<Class>(pair.second);
  202. return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
  203. }
  204. default:
  205. MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class "
  206. << ", but got " << pair.second->ToString();
  207. }
  208. }
  209. FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
  210. FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  211. ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  212. ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  213. ptrGraph->debug_info()->set_name("map");
  214. AnfNodePtr ptrFnArg = nullptr;
  215. std::size_t i = 0;
  216. if (fn_leaf_ == nullptr) {
  217. ptrFnArg = ptrGraph->add_parameter();
  218. i = 1;
  219. }
  220. ArgsPairList arg_pairs;
  221. std::size_t size = args_spec_list.size();
  222. for (; i < size; ++i) {
  223. MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString();
  224. arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
  225. }
  226. ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs));
  227. return ptrGraph;
  228. }
  229. abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  230. if (fn_leaf_ == nullptr) {
  231. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  232. // Assert that map's function param does not contain free variables
  233. if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
  234. auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
  235. auto func_graph = graph_func->func_graph();
  236. if (func_graph->parent() != nullptr) {
  237. MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet.";
  238. }
  239. }
  240. }
  241. AbstractBasePtrList broadened;
  242. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
  243. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  244. MS_EXCEPTION_IF_NULL(arg);
  245. return arg->Broaden();
  246. });
  247. return broadened;
  248. }
  249. REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
  250. (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
  251. .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
  252. .def(py::init<>());
  253. }));
  254. } // namespace prim
  255. } // namespace mindspore