You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resolve.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/parse/resolve.h"
  17. #include <string>
  18. #include <memory>
  19. #include <vector>
  20. #include <algorithm>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "pipeline/parse/parse.h"
  23. #include "pipeline/parse/python_adapter.h"
  24. #include "utils/any.h"
  25. #include "operator/ops.h"
  26. #include "optimizer/opt.h"
  27. #include "optimizer/irpass.h"
  28. #include "./common.h"
  29. namespace mindspore {
  30. namespace parse {
  31. abstract::AbstractBasePtr ClassObject::ToAbstract() {
  32. ClassPtr cls_ptr = ParseDataClass(obj());
  33. auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
  34. abs_scalar->set_type(std::make_shared<TypeType>());
  35. abs_scalar->set_value(cls_ptr);
  36. AbstractBasePtrList args_spec_list = {abs_scalar};
  37. auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeRecord);
  38. return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
  39. }
  40. abstract::AbstractBasePtr ClassType::ToAbstract() {
  41. auto abs_scalar =
  42. std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
  43. AbstractBasePtrList args_spec_list = {abs_scalar};
  44. auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
  45. auto ret_val = std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
  46. ret_val->set_value_desc(ToString());
  47. return ret_val;
  48. }
  49. // call python PYTHON_MOD_RESOLVE_FUNCTION interface to resolve the symbol in corresponding namespace
  50. bool SymbolResolver::Resolve() {
  51. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  52. py::object obj = namespace_->obj();
  53. std::string symbol = symbol_->symbol();
  54. if (py::isinstance<py::none>(obj)) {
  55. MS_LOG(ERROR) << "Unresolved symbol: " << symbol;
  56. return false;
  57. }
  58. result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol));
  59. return true;
  60. }
  61. namespace {
  62. // argument obj should be python Parameter object
  63. // it will be converted to Parameter node here
  64. AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
  65. MS_EXCEPTION_IF_NULL(func_graph);
  66. // parameter object should not be none
  67. if (py::isinstance<py::none>(obj)) {
  68. MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null.";
  69. }
  70. if (!py::hasattr(obj, "name")) {
  71. MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj";
  72. }
  73. // get the parameter name from parameter object
  74. auto name_attr = python_adapter::GetPyObjAttr(obj, "name");
  75. if (py::isinstance<py::none>(name_attr)) {
  76. MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
  77. }
  78. std::string param_name = py::cast<std::string>(name_attr);
  79. auto top_graph = Parser::GetTopFuncGraph();
  80. // if the parameter node has been created , return it
  81. AnfNodePtr para_node = nullptr;
  82. for (auto param : top_graph->parameters()) {
  83. auto param_node = dyn_cast<Parameter>(param);
  84. if (param_node != nullptr && param_node->name() == param_name) {
  85. para_node = param;
  86. break;
  87. }
  88. }
  89. if (para_node == nullptr) {
  90. ParameterPtr node = top_graph->AddWeightParameter(param_name);
  91. node->set_default_param(obj);
  92. // set_abstract for parameter
  93. auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
  94. ValuePtr converted = nullptr;
  95. (void)ConvertData(to_convert, &converted);
  96. bool broaden = true;
  97. node->set_abstract(abstract::FromValue(converted, broaden));
  98. para_node = node;
  99. }
  100. auto iter = func_graph->make_ref_params().find(para_node);
  101. if (iter == func_graph->make_ref_params().end()) {
  102. AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node);
  103. AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
  104. AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
  105. AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node});
  106. func_graph->make_ref_params()[para_node] = ref_node;
  107. func_graph->add_parameter_obj_node(ref_node);
  108. return ref_node;
  109. } else {
  110. return iter->second;
  111. }
  112. }
  113. bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
  114. AnfNodePtr output = nullptr;
  115. if (py::hasattr(obj, "__parameter__")) {
  116. auto param = ResolveParameterObj(func_graph, obj);
  117. if (param == nullptr) {
  118. MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";
  119. return false;
  120. }
  121. MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
  122. output = param;
  123. } else if (py::hasattr(obj, "__parameter_tuple__")) {
  124. auto tuple = obj.cast<py::tuple>();
  125. std::vector<AnfNodePtr> args;
  126. args.push_back(NewValueNode(prim::kPrimMakeTuple));
  127. for (size_t it = 0; it < tuple.size(); ++it) {
  128. AnfNodePtr out = nullptr;
  129. bool success = ResolveObjectToNode(func_graph, tuple[it], &out);
  130. if (!success) {
  131. MS_LOG(ERROR) << "Resolve object to node failed";
  132. return false;
  133. }
  134. args.push_back(out);
  135. }
  136. output = NewCNode(args, func_graph);
  137. } else {
  138. ValuePtr convert_result = nullptr;
  139. bool converted = ConvertData(obj, &convert_result, parse::python_adapter::UseSignatureInResolve());
  140. if (!converted) {
  141. MS_LOG(ERROR) << "Convert data failed";
  142. return false;
  143. }
  144. MS_EXCEPTION_IF_NULL(convert_result);
  145. output = NewValueNode(convert_result);
  146. if (convert_result->isa<tensor::Tensor>()) {
  147. output = GetMixedPrecisionCastHelp(func_graph, output);
  148. }
  149. }
  150. *node = output;
  151. return true;
  152. }
  153. // transform the ValueTuple or ValueList of graph node to make tuple of const graph node
  154. bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node,
  155. const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
  156. MS_EXCEPTION_IF_NULL(value_node);
  157. const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value());
  158. bool has_graph_in_list = false;
  159. for (auto &elemv : value_vec) {
  160. MS_EXCEPTION_IF_NULL(elemv);
  161. if (elemv->isa<FuncGraph>()) {
  162. FuncGraphPtr new_fg = elemv->cast<FuncGraphPtr>();
  163. manager->AddFuncGraph(new_fg);
  164. has_graph_in_list = true;
  165. continue;
  166. }
  167. if (has_graph_in_list) {
  168. MS_LOG(EXCEPTION) << "List has graph in it, but not all is graph";
  169. }
  170. }
  171. // The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
  172. // So if has graph in list, try to replace the node with make tuple of graph value node.
  173. if (has_graph_in_list) {
  174. // change the vector of graph to be make_list of graph value node
  175. std::vector<AnfNodePtr> list_vec;
  176. auto make_list_op = NewValueNode(prim::kPrimMakeTuple);
  177. list_vec.emplace_back(make_list_op);
  178. (void)std::transform(std::begin(value_vec), std::end(value_vec), std::back_inserter(list_vec),
  179. [](const ValuePtr &value) { return NewValueNode(value); });
  180. FuncGraphPtr cnode_graph = nullptr;
  181. auto users = manager->node_users()[node];
  182. for (auto &use : users) {
  183. auto use_node = use.first;
  184. MS_EXCEPTION_IF_NULL(use_node);
  185. if (use_node->isa<CNode>()) {
  186. cnode_graph = use_node->func_graph();
  187. }
  188. }
  189. if (cnode_graph) {
  190. CNodePtr list_app = cnode_graph->NewCNode(list_vec);
  191. // replace the ret ptr to be make_list of graph value node
  192. *transformed = list_app;
  193. } else {
  194. MS_LOG(EXCEPTION) << "Can not find apply for node use when replacing node of vector of graph";
  195. }
  196. }
  197. return true;
  198. }
  199. } // namespace
  200. AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
  201. const AnfNodePtr &node) {
  202. if (node->func_graph() == nullptr || manager == nullptr) {
  203. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
  204. }
  205. SymbolResolver symbol_resolver(name_space, symbol, node);
  206. if (!symbol_resolver.Resolve()) {
  207. MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  208. }
  209. py::object obj = symbol_resolver.result();
  210. ScopeGuard scope_guard(node->scope());
  211. AnfNodePtr resolved_node = nullptr;
  212. TraceManager::DebugTrace(std::make_shared<TraceResolve>(node->debug_info()));
  213. bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node);
  214. if (!success) {
  215. MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  216. }
  217. if (IsValueNode<FuncGraph>(resolved_node)) {
  218. auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node);
  219. manager->AddFuncGraph(new_fg);
  220. }
  221. // if the constant node is constant of vector of graph ,add graph to manager
  222. if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
  223. (void)TransformVectorGraphValueNode(manager, node, resolved_node->cast<ValueNodePtr>(), &resolved_node);
  224. }
  225. TraceManager::EndTrace();
  226. return resolved_node;
  227. }
  228. namespace {
  229. opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
  230. opt::OptPassGroupMap map({
  231. {"resolve",
  232. {
  233. // for resolve and getattr primitive;
  234. irpass.resolver_resolve_,
  235. irpass.resolver_getattr_,
  236. }},
  237. });
  238. return map;
  239. }
  240. } // namespace
  241. bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) {
  242. if (func_graph == nullptr || res == nullptr) {
  243. MS_LOG(ERROR) << "func_graph or resource is null";
  244. return false;
  245. }
  246. opt::irpass::ResolveIRPassLib irpass;
  247. opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass));
  248. (void)parse::python_adapter::set_python_scoped();
  249. MS_EXCEPTION_IF_NULL(opt_resolve);
  250. (void)opt_resolve->step(func_graph, use_profile);
  251. return true;
  252. }
  253. bool ResolveAll(const FuncGraphManagerPtr &manager) {
  254. if (manager == nullptr) {
  255. MS_LOG(ERROR) << "func graph manager is null";
  256. return false;
  257. }
  258. if (manager->roots().size() > 1) {
  259. MS_LOG(WARNING)
  260. << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs"
  261. "called from root graph, so it's not necessary to pass all graphs as roots. "
  262. "Please ensure your usage.";
  263. }
  264. // should not use pipeline::Resource as Resource::Clean will clean some
  265. // global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop
  266. // fail as valid scope has been cleaned.
  267. auto res = std::make_shared<pipeline::ResourceBase>();
  268. res->set_manager(manager);
  269. auto roots = manager->roots();
  270. for (auto &fg : roots) {
  271. bool ret = ResolveFuncGraph(fg, res, false);
  272. if (!ret) {
  273. MS_EXCEPTION_IF_NULL(fg);
  274. MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed";
  275. }
  276. }
  277. return true;
  278. }
  279. } // namespace parse
  280. } // namespace mindspore