You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resolve.cc 17 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/jit/parse/resolve.h"
  17. #include <string>
  18. #include <memory>
  19. #include <vector>
  20. #include "ir/param_info.h"
  21. #include "pipeline/jit/parse/data_converter.h"
  22. #include "pipeline/jit/parse/parse.h"
  23. #include "pipeline/jit/parse/python_adapter.h"
  24. #include "utils/any.h"
  25. #include "frontend/operator/ops.h"
  26. #include "frontend/optimizer/opt.h"
  27. #include "frontend/optimizer/irpass.h"
  28. #include "frontend/optimizer/irpass/symbol_resolver.h"
  29. namespace mindspore {
  30. namespace parse {
  31. abstract::AbstractBasePtr ClassObject::ToAbstract() {
  32. ClassPtr cls_ptr = ParseDataClass(obj());
  33. auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
  34. abs_scalar->set_type(std::make_shared<TypeType>());
  35. abs_scalar->set_value(cls_ptr);
  36. AbstractBasePtrList args_spec_list = {abs_scalar};
  37. auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeRecord);
  38. return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
  39. }
  40. static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
  41. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  42. auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
  43. if (!py::isinstance<py::bool_>(res)) {
  44. MS_LOG(ERROR) << "Expect a bool type, but got " << py::str(res);
  45. return false;
  46. }
  47. return res.cast<bool>();
  48. }
  49. abstract::AbstractBasePtr ClassType::ToAbstract() {
  50. auto abs_scalar =
  51. std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
  52. // The fallback feature is enabled in default.
  53. // Not support change the flag during the process is alive.
  54. static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
  55. static const auto use_fallback = (support_fallback != "0");
  56. if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
  57. return abs_scalar;
  58. }
  59. AbstractBasePtrList args_spec_list = {abs_scalar};
  60. auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
  61. auto ret_val = std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
  62. ret_val->set_value_desc(ToString());
  63. return ret_val;
  64. }
  65. // call python PYTHON_MOD_RESOLVE_FUNCTION interface to resolve the symbol in corresponding namespace
  66. bool SymbolResolver::Resolve() {
  67. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  68. py::object obj = namespace_->obj();
  69. std::string symbol = symbol_->symbol();
  70. if (py::isinstance<py::none>(obj)) {
  71. MS_EXCEPTION(NameError) << "The name \'" << symbol << "\' is not defined.";
  72. }
  73. result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol));
  74. return true;
  75. }
  76. namespace {
  77. // If any mixed precision flag add a cast node after the parameter node.
  78. // argument obj should be python Parameter object
  79. // it will be converted to Parameter node here
  80. AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
  81. MS_EXCEPTION_IF_NULL(func_graph);
  82. // Parameter object should not be none
  83. if (py::isinstance<py::none>(obj)) {
  84. MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null.";
  85. }
  86. if (!py::hasattr(obj, "name")) {
  87. MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj";
  88. }
  89. // Get the parameter name from parameter object
  90. auto name_attr = python_adapter::GetPyObjAttr(obj, "name");
  91. if (py::isinstance<py::none>(name_attr)) {
  92. MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
  93. }
  94. auto param_name = py::cast<std::string>(name_attr);
  95. auto top_func_graph = Parser::GetTopFuncGraph();
  96. // If the parameter node has been created , return it.
  97. AnfNodePtr para_node = nullptr;
  98. for (auto const &param : top_func_graph->parameters()) {
  99. auto param_node = dyn_cast<Parameter>(param);
  100. if (param_node != nullptr && param_node->name() == param_name && !param_node->is_top_graph_param()) {
  101. para_node = param;
  102. MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
  103. << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
  104. break;
  105. }
  106. }
  107. if (para_node == nullptr) {
  108. auto node = top_func_graph->AddWeightParameter(param_name);
  109. auto value = py::cast<tensor::MetaTensorPtr>(obj);
  110. node->set_default_param(value);
  111. // Set abstract for parameter
  112. auto abs = value->ToAbstract();
  113. node->set_abstract(abs);
  114. para_node = node;
  115. MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
  116. << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
  117. }
  118. func_graph->add_parameter_obj_node(para_node);
  119. return para_node;
  120. }
  121. void BroadenCNodeAbstract(const FuncGraphPtr &func_graph) {
  122. std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  123. for (const AnfNodePtr &node : nodes) {
  124. if (!node->isa<CNode>()) {
  125. continue;
  126. }
  127. auto abstract = node->abstract();
  128. if (abstract != nullptr) {
  129. node->set_abstract(abstract->Broaden());
  130. }
  131. }
  132. }
  133. void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  134. if (!value->isa<FuncGraph>()) {
  135. return;
  136. }
  137. auto resolved_graph = value->cast<FuncGraphPtr>();
  138. MS_EXCEPTION_IF_NULL(resolved_graph);
  139. if (!resolved_graph->has_attr("is_load")) {
  140. return;
  141. }
  142. auto top_graph = Parser::GetTopFuncGraph();
  143. std::vector<AnfNodePtr> input_params;
  144. for (auto const &param : resolved_graph->parameters()) {
  145. auto param_ptr = dyn_cast<Parameter>(param);
  146. MS_EXCEPTION_IF_NULL(param_ptr);
  147. if (param_ptr->has_default()) {
  148. param_ptr->set_func_graph(top_graph);
  149. func_graph->add_parameter_obj_node(param_ptr);
  150. // Update top_graph
  151. top_graph->add_parameter(param_ptr);
  152. size_t hyper_param_count = top_graph->hyper_param_count();
  153. top_graph->set_hyper_param_count(hyper_param_count + 1);
  154. } else {
  155. input_params.push_back(param_ptr);
  156. }
  157. }
  158. resolved_graph->set_parameters(input_params);
  159. BroadenCNodeAbstract(resolved_graph);
  160. }
  161. bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
  162. AnfNodePtr output = nullptr;
  163. if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
  164. auto param = ResolveParameterObj(func_graph, obj);
  165. if (param == nullptr) {
  166. MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";
  167. return false;
  168. }
  169. MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
  170. output = param;
  171. } else if (py::hasattr(obj, "__parameter_tuple__")) {
  172. auto tuple = obj.cast<py::tuple>();
  173. std::vector<AnfNodePtr> args;
  174. args.push_back(NewValueNode(prim::kPrimMakeTuple));
  175. for (size_t it = 0; it < tuple.size(); ++it) {
  176. AnfNodePtr out = nullptr;
  177. bool success = ResolveObjectToNode(func_graph, tuple[it], &out);
  178. if (!success) {
  179. MS_LOG(ERROR) << "Resolve object to node failed";
  180. return false;
  181. }
  182. args.push_back(out);
  183. }
  184. output = NewCNode(std::move(args), func_graph);
  185. } else {
  186. ValuePtr convert_result = nullptr;
  187. bool converted = ConvertData(obj, &convert_result, parse::python_adapter::UseSignatureInResolve());
  188. if (!converted) {
  189. MS_LOG(ERROR) << "Convert data failed";
  190. return false;
  191. }
  192. MS_EXCEPTION_IF_NULL(convert_result);
  193. ConvertLoadedGraph(func_graph, convert_result);
  194. output = NewValueNode(convert_result);
  195. if (convert_result->isa<tensor::Tensor>()) {
  196. output = GetMixedPrecisionCastHelp(func_graph, output);
  197. }
  198. }
  199. *node = output;
  200. return true;
  201. }
  202. bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
  203. if (value_vec.empty()) {
  204. return false;
  205. }
  206. for (auto &elem : value_vec) {
  207. MS_EXCEPTION_IF_NULL(elem);
  208. if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
  209. const auto &vec = GetValue<ValuePtrList>(elem);
  210. auto is_graph = IsAllFuncInValueSequence(vec);
  211. if (!is_graph) {
  212. return false;
  213. }
  214. } else if (!elem->isa<FuncGraph>() && !elem->isa<Primitive>()) {
  215. return false;
  216. }
  217. }
  218. return true;
  219. }
  220. AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
  221. const std::vector<ValuePtr> &value_vec) {
  222. std::vector<AnfNodePtr> nodes;
  223. nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  224. for (auto &elem : value_vec) {
  225. MS_EXCEPTION_IF_NULL(elem);
  226. AnfNodePtr node = nullptr;
  227. if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
  228. const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
  229. node = TransformToMakeTupleNodes(manager, func_graph, vec);
  230. } else if (elem->isa<FuncGraph>()) {
  231. FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
  232. manager->AddFuncGraph(new_fg);
  233. node = NewValueNode(new_fg);
  234. } else if (elem->isa<Primitive>()) {
  235. node = NewValueNode(elem);
  236. } else {
  237. MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
  238. }
  239. nodes.emplace_back(node);
  240. }
  241. auto cnode = func_graph->NewCNode(std::move(nodes));
  242. return cnode;
  243. }
  244. // Transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node
  245. bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
  246. const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
  247. MS_EXCEPTION_IF_NULL(value_node);
  248. const auto &value_vec = GetValue<ValuePtrList>(value_node->value());
  249. if (!IsAllFuncInValueSequence(value_vec)) {
  250. return false;
  251. }
  252. // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
  253. // So if has graph in list, try to replace the node with make tuple of graph value node.
  254. // We do this because the graph manager won't investigate the graph inside valuetuple,
  255. // change the vector of graph to be make_tuple of graph value node.
  256. // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
  257. // independent nodes.
  258. auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
  259. // Replace the ret ptr to be make tuple of graph value node
  260. *transformed = node_tuple_graphs;
  261. return true;
  262. }
  263. // Resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager.
  264. AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
  265. const AnfNodePtr &node) {
  266. ScopeGuard scope_guard(node->scope());
  267. AnfNodePtr resolved_node = nullptr;
  268. bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node);
  269. if (!success) {
  270. MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo.";
  271. }
  272. if (IsValueNode<FuncGraph>(resolved_node)) {
  273. auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node);
  274. manager->AddFuncGraph(new_fg);
  275. }
  276. // If the constant node is constant of vector of graph, add graph to manager.
  277. if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
  278. (void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
  279. &resolved_node);
  280. }
  281. return resolved_node;
  282. }
  283. } // namespace
  284. AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
  285. const AnfNodePtr &node) {
  286. MS_EXCEPTION_IF_NULL(node);
  287. TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
  288. if (node->func_graph() == nullptr || manager == nullptr) {
  289. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
  290. }
  291. SymbolResolver symbol_resolver(name_space, symbol, node);
  292. symbol_resolver.Resolve();
  293. py::object obj = symbol_resolver.result();
  294. AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
  295. return resolved_node;
  296. }
  297. AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
  298. const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr) {
  299. MS_EXCEPTION_IF_NULL(node);
  300. TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
  301. if (node->func_graph() == nullptr || manager == nullptr) {
  302. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
  303. }
  304. SymbolResolver symbol_resolver(name_space, symbol, node);
  305. if (!symbol_resolver.Resolve()) {
  306. MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo.";
  307. }
  308. py::object obj = symbol_resolver.result();
  309. if (!data_converter::IsCellInstance(obj)) {
  310. AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
  311. AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
  312. AnfNodePtr res_node = node->func_graph()->NewCNode(std::move(inputs));
  313. TraceManager::ClearParseOrResolveDebugInfo();
  314. return res_node;
  315. }
  316. const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
  317. const std::string module = "mindspore._extends.parse.parser";
  318. py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj);
  319. auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
  320. std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
  321. auto new_symbol = std::make_shared<Symbol>(attr_as_string);
  322. MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();
  323. AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
  324. AnfNodePtr resolved_node = node->func_graph()->NewCNode(std::move(inputs));
  325. TraceManager::ClearParseOrResolveDebugInfo();
  326. return resolved_node;
  327. }
  328. namespace {
  329. opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
  330. // For resolve and getattr primitive.
  331. opt::OptPassGroupMap map({
  332. {"resolve",
  333. {
  334. irpass.resolver_getattr_resolve_,
  335. }},
  336. });
  337. return map;
  338. }
  339. } // namespace
  340. bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) {
  341. if (func_graph == nullptr || res == nullptr) {
  342. MS_LOG(ERROR) << "func_graph or resource is null";
  343. return false;
  344. }
  345. opt::irpass::ResolveIRPassLib irpass;
  346. opt::OptimizerPtr opt_resolve =
  347. opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false);
  348. (void)parse::python_adapter::set_python_scoped();
  349. MS_EXCEPTION_IF_NULL(opt_resolve);
  350. (void)opt_resolve->step(func_graph, use_profile);
  351. return true;
  352. }
  353. bool ResolveAll(const FuncGraphManagerPtr &manager) {
  354. if (manager == nullptr) {
  355. MS_LOG(ERROR) << "func graph manager is null";
  356. return false;
  357. }
  358. if (manager->roots().size() > 1) {
  359. MS_LOG(WARNING)
  360. << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs"
  361. "called from root graph, so it's not necessary to pass all graphs as roots. "
  362. "Please ensure your usage.";
  363. }
  364. // Should not use pipeline::Resource as Resource::Clean will clean some
  365. // global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop
  366. // fail as valid scope has been cleaned.
  367. auto res = std::make_shared<pipeline::ResourceBase>();
  368. res->set_manager(manager);
  369. auto roots = manager->roots();
  370. for (auto &fg : roots) {
  371. bool ret = ResolveFuncGraph(fg, res, false);
  372. if (!ret) {
  373. MS_EXCEPTION_IF_NULL(fg);
  374. MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed";
  375. }
  376. }
  377. return true;
  378. }
  379. } // namespace parse
  380. } // namespace mindspore