You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

map.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/operator/composite/map.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <utility>
  20. #include <vector>
  21. #include "ir/anf.h"
  22. #include "ir/func_graph.h"
  23. #include "abstract/abstract_value.h"
  24. #include "abstract/abstract_function.h"
  25. #include "abstract/dshape.h"
  26. #include "pybind_api/api_register.h"
  27. #include "debug/trace.h"
  28. #include "frontend/operator/ops.h"
  29. namespace mindspore {
  30. // namespace to support composite operators definition
  31. namespace prim {
  32. using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
  33. AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
  34. MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
  35. MS_EXCEPTION_IF_NULL(func_graph);
  36. std::vector<AnfNodePtr> inputs;
  37. if (fn_arg != nullptr) {
  38. inputs.emplace_back(fn_arg);
  39. } else {
  40. inputs.emplace_back(NewValueNode(fn_leaf_));
  41. }
  42. inputs.insert(inputs.end(), args.begin(), args.end());
  43. return func_graph->NewCNode(inputs);
  44. }
  45. FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
  46. // Generate func for leaf nodes
  47. FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  48. ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  49. ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  50. ptrGraph->debug_info()->set_name("map");
  51. AnfNodePtr ptrFnArg = nullptr;
  52. if (fn_leaf_ == nullptr) {
  53. ptrFnArg = ptrGraph->add_parameter();
  54. }
  55. AnfNodePtrList args;
  56. for (size_t i = 0; i < args_size; ++i) {
  57. args.emplace_back(ptrGraph->add_parameter());
  58. }
  59. ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args));
  60. return ptrGraph;
  61. }
  62. AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
  63. const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  64. MS_EXCEPTION_IF_NULL(func_graph);
  65. MS_EXCEPTION_IF_NULL(type);
  66. std::size_t size = type->elements().size();
  67. bool is_not_same =
  68. std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  69. auto lhs = std::dynamic_pointer_cast<List>(item.second);
  70. MS_EXCEPTION_IF_NULL(lhs);
  71. return lhs->elements().size() != size;
  72. });
  73. if (is_not_same) {
  74. MS_LOG(EXCEPTION) << "List in Map should have same length";
  75. }
  76. std::vector<AnfNodePtr> inputs;
  77. inputs.push_back(NewValueNode(prim::kPrimMakeList));
  78. for (int i = 0; i < SizeToInt(size); ++i) {
  79. MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target";
  80. auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
  81. auto fn = NewValueNode(ptrGraph);
  82. std::vector<AnfNodePtr> inputs2;
  83. inputs2.push_back(fn);
  84. if (fn_arg != nullptr) {
  85. inputs2.push_back(fn_arg);
  86. }
  87. (void)std::transform(
  88. arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
  89. [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
  90. return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
  91. });
  92. inputs.push_back(func_graph->NewCNode(inputs2));
  93. }
  94. return func_graph->NewCNode(inputs);
  95. }
  96. AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
  97. const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  98. MS_EXCEPTION_IF_NULL(func_graph);
  99. MS_EXCEPTION_IF_NULL(type);
  100. std::size_t size = type->elements().size();
  101. bool is_not_same =
  102. std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  103. auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
  104. MS_EXCEPTION_IF_NULL(lhs);
  105. return lhs->elements().size() != size;
  106. });
  107. if (is_not_same) {
  108. MS_LOG(EXCEPTION) << "tuple in Map should have same length";
  109. }
  110. std::vector<AnfNodePtr> inputs;
  111. inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  112. for (int i = 0; i < SizeToInt(size); ++i) {
  113. MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs";
  114. auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
  115. auto fn = NewValueNode(ptrGraph);
  116. std::vector<AnfNodePtr> inputs2;
  117. inputs2.push_back(fn);
  118. if (fn_arg != nullptr) {
  119. inputs2.push_back(fn_arg);
  120. }
  121. (void)std::transform(
  122. arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
  123. [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
  124. return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
  125. });
  126. inputs.push_back(func_graph->NewCNode(inputs2));
  127. }
  128. return func_graph->NewCNode(inputs);
  129. }
  130. AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
  131. const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  132. MS_EXCEPTION_IF_NULL(type);
  133. MS_EXCEPTION_IF_NULL(func_graph);
  134. std::vector<AnfNodePtr> inputs;
  135. inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  136. inputs.push_back(NewValueNode(type));
  137. std::size_t attrSize = type->GetAttributes().size();
  138. for (std::size_t i = 0; i < attrSize; ++i) {
  139. MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs";
  140. auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
  141. auto fn = NewValueNode(ptrGraph);
  142. std::vector<AnfNodePtr> inputs2;
  143. inputs2.push_back(fn);
  144. if (fn_arg != nullptr) {
  145. inputs2.push_back(fn_arg);
  146. }
  147. int j = 0;
  148. for (auto item : arg_pairs) {
  149. inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
  150. j++;
  151. }
  152. inputs.push_back(func_graph->NewCNode(inputs2));
  153. }
  154. return func_graph->NewCNode(inputs);
  155. }
  156. AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  157. if (arg_pairs.empty()) {
  158. MS_EXCEPTION(TypeError) << "map() must have at least two arguments";
  159. }
  160. bool found = false;
  161. TypeId id = kObjectTypeEnd;
  162. std::pair<AnfNodePtr, TypePtr> pair;
  163. for (auto &item : arg_pairs) {
  164. pair = item;
  165. MS_LOG(DEBUG) << "Map " << pair.second->ToString();
  166. id = item.second->type_id();
  167. if (nonleaf_.count(id)) {
  168. found = true;
  169. break;
  170. }
  171. }
  172. if (found) {
  173. // In a nonleaf situation, all arguments must have the same generic.
  174. bool is_not_same =
  175. std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
  176. if (item.first != pair.first) {
  177. return item.second->type_id() != pair.second->type_id();
  178. }
  179. return false;
  180. });
  181. if (is_not_same) {
  182. std::ostringstream oss;
  183. oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
  184. << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
  185. int idx = 0;
  186. for (auto &item : arg_pairs) {
  187. oss << ++idx << ": " << item.second->ToString() << "\n";
  188. }
  189. MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n"
  190. << oss.str() << pair.second->ToString() << "\n";
  191. }
  192. }
  193. switch (id) {
  194. case kObjectTypeList: {
  195. auto type = std::static_pointer_cast<List>(pair.second);
  196. return FullMakeList(type, func_graph, fn_arg, arg_pairs);
  197. }
  198. case kObjectTypeTuple: {
  199. auto type = std::static_pointer_cast<Tuple>(pair.second);
  200. return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
  201. }
  202. case kObjectTypeClass: {
  203. auto type = std::static_pointer_cast<Class>(pair.second);
  204. return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
  205. }
  206. default:
  207. MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class "
  208. << ", but got " << pair.second->ToString();
  209. }
  210. }
  211. FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
  212. FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  213. ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  214. ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  215. ptrGraph->debug_info()->set_name("map");
  216. AnfNodePtr ptrFnArg = nullptr;
  217. std::size_t i = 0;
  218. if (fn_leaf_ == nullptr) {
  219. ptrFnArg = ptrGraph->add_parameter();
  220. i = 1;
  221. }
  222. ArgsPairList arg_pairs;
  223. std::size_t size = args_spec_list.size();
  224. for (; i < size; ++i) {
  225. MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString();
  226. arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
  227. }
  228. ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs));
  229. return ptrGraph;
  230. }
  231. abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  232. if (fn_leaf_ == nullptr) {
  233. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  234. // Assert that map's function param does not contain free variables
  235. if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
  236. auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
  237. auto func_graph = graph_func->func_graph();
  238. if (func_graph->parent() != nullptr) {
  239. MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet.";
  240. }
  241. }
  242. }
  243. AbstractBasePtrList broadened;
  244. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
  245. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  246. MS_EXCEPTION_IF_NULL(arg);
  247. return arg->Broaden();
  248. });
  249. return broadened;
  250. }
  251. REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
  252. (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
  253. .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
  254. .def(py::init<>());
  255. }));
  256. } // namespace prim
  257. } // namespace mindspore