You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

action.cc 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/jit/action.h"
  17. #include <memory>
  18. #include <utility>
  19. #include <vector>
  20. #include <string>
  21. #include <algorithm>
  22. #include <functional>
  23. #include "ir/func_graph_cloner.h"
  24. #include "ir/param_info.h"
  25. #include "ir/cell.h"
  26. #include "frontend/parallel/costmodel_context.h"
  27. #include "frontend/parallel/context.h"
  28. #include "pipeline/jit/pass.h"
  29. #include "pipeline/jit/parse/parse_base.h"
  30. #include "pipeline/jit/parse/data_converter.h"
  31. #include "abstract/abstract_value.h"
  32. #include "pipeline/jit/static_analysis/static_analysis.h"
  33. #include "pipeline/jit/static_analysis/program_specialize.h"
  34. #include "pipeline/jit/resource.h"
  35. #include "utils/ms_context.h"
  36. #include "pipeline/jit/remove_value_node_dup.h"
  37. #include "frontend/optimizer/optimizer.h"
  38. #include "vm/transform.h"
  39. #include "parse/python_adapter.h"
  40. #include "frontend/optimizer/py_pass_manager.h"
  41. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  42. #include "ps/parameter_server.h"
  43. #include "ps/scheduler.h"
  44. #include "ps/worker.h"
  45. #endif
  46. namespace mindspore {
  47. namespace pipeline {
  48. using CompileGraphs = compile::CompileGraphs;
  49. using abstract::AnalysisResult;
  50. using mindspore::abstract::AnalysisContextPtr;
  51. abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
  52. const abstract::AbstractBasePtrList &args_spec, bool clear) {
  53. MS_LOG(DEBUG) << "AbstractAnalyze start";
  54. auto engine = res->engine();
  55. MS_EXCEPTION_IF_NULL(engine);
  56. if (clear) {
  57. auto manager = res->manager();
  58. MS_EXCEPTION_IF_NULL(manager);
  59. engine->Clear();
  60. for (auto &node : manager->all_nodes()) {
  61. MS_EXCEPTION_IF_NULL(node);
  62. const AbstractBasePtr &prev_inferred = node->abstract();
  63. // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
  64. if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
  65. node->set_abstract(nullptr);
  66. MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
  67. }
  68. }
  69. }
  70. auto ret = engine->Run(func_graph, args_spec);
  71. MS_LOG(DEBUG) << "AbstractAnalyze end";
  72. return ret;
  73. }
  74. FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
  75. const abstract::AnalysisContextPtr &context) {
  76. MS_LOG(DEBUG) << "ProgramSpecialize start";
  77. abstract::ProgramSpecializer spc(res->engine());
  78. FuncGraphPtr result = spc.Run(func_graph, context);
  79. auto manager = res->manager();
  80. MS_EXCEPTION_IF_NULL(manager);
  81. manager->KeepRoots({result});
  82. MS_LOG(DEBUG) << "ProgramSpecialize end";
  83. return result;
  84. }
// Re-infer and re-specialize `func_graph` after optimization passes may have
// invalidated its inferred abstracts.  Runs AbstractAnalyze with clear=true,
// then ProgramSpecialize, updates the resource's func_graph, and returns the
// renormalized graph.  Timing of the two phases is recorded when profiling
// is enabled.
FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                         const abstract::AbstractBasePtrList &args_spec) {
  MS_LOG(DEBUG) << "Renormalize start";
#ifdef ENABLE_PROFILE
  double t1 = GetTime();
#endif
  abstract::AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec, true);
#ifdef ENABLE_PROFILE
  double t2 = GetTime();
#endif
  // NOTE(review): specialization runs on the original `func_graph` rather
  // than result.context->func_graph() (which AbstractSpecializeAction uses);
  // presumably the two coincide on this path — confirm.
  auto ret = ProgramSpecialize(res, func_graph, result.context);
  res->set_func_graph(ret);
#ifdef ENABLE_PROFILE
  double t3 = GetTime();
  MsProfile::StatTime("renormalize.infer", t2 - t1);
  MsProfile::StatTime("renormalize.specialize", t3 - t2);
#endif
  MS_LOG(DEBUG) << "Renormalize end";
  return ret;
}
  105. bool ParseAction(const ResourcePtr &res) {
  106. if (!res->input()) {
  107. MS_LOG(EXCEPTION) << "Parse error";
  108. }
  109. py::object input = res->input();
  110. parse::Parser::InitParserEnvironment(input);
  111. py::module path = py::module::import("os.path");
  112. std::string dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();
  113. parse::python_adapter::set_python_env_flag(true);
  114. parse::python_adapter::SetPythonPath(dir);
  115. ValuePtr converted_ret = nullptr;
  116. bool converted = parse::ConvertData(input, &converted_ret, true);
  117. if (!converted) {
  118. MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(input));
  119. }
  120. FuncGraphPtr top_graph = nullptr;
  121. if (py::isinstance<Cell>(input)) {
  122. top_graph = parse::MakeTopGraph(input, converted_ret);
  123. } else if (converted_ret->isa<FuncGraph>()) {
  124. top_graph = converted_ret->cast<FuncGraphPtr>();
  125. } else {
  126. MS_LOG(EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell.";
  127. }
  128. parse::Parser::UpdateTopFuncGraph(top_graph);
  129. res->set_func_graph(top_graph);
  130. FuncGraphManagerPtr manager = res->manager();
  131. if (manager == nullptr) {
  132. MS_LOG(EXCEPTION) << "Manager is nullptr.";
  133. }
  134. manager->AddFuncGraph(top_graph);
  135. return true;
  136. }
  137. // obj_map's graphs have the same construct, these graphs can be optimized to one graph.
  138. // This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}->
  139. // graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
  140. // all obj_map's graph shared base_graph
  141. bool CombineLikeGraphs(const ResourcePtr &res) {
  142. auto &obj_map = parse::data_converter::GetObjGraphs();
  143. for (auto it : obj_map) {
  144. auto &graphs = it.second;
  145. MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
  146. auto fg = graphs[0];
  147. FuncGraphPtrList func_graphs = {fg};
  148. ClonerPtr cloner = std::make_shared<Cloner>(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
  149. std::make_shared<TraceCombileLikeGraphs>());
  150. cloner->Run();
  151. auto base_graph = cloner->cloned_func_graph()[fg];
  152. MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
  153. if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) {
  154. continue;
  155. }
  156. auto &cloned_nodes = *cloner->cloned_node();
  157. for (auto &fv : fg->paramter_obj_nodes()) {
  158. TraceManager::DebugTrace(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
  159. auto param = base_graph->add_parameter();
  160. TraceManager::EndTrace();
  161. auto &node_users = res->manager()->node_users()[fv];
  162. for (auto &n : node_users) {
  163. // If the user is not in this graph, no need to change.
  164. auto cloned = cloned_nodes[n.first];
  165. if (cloned == nullptr) {
  166. continue;
  167. }
  168. auto repl_n = cloned->cast<CNodePtr>();
  169. repl_n->set_input(n.second, param);
  170. }
  171. }
  172. MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();
  173. for (auto &g : graphs) {
  174. auto fvs = g->paramter_obj_nodes();
  175. std::vector<AnfNodePtr> new_node_inputs;
  176. new_node_inputs.push_back(NewValueNode(base_graph));
  177. for (auto &p : g->parameters()) {
  178. AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p);
  179. new_node_inputs.push_back(para_after_cast);
  180. }
  181. (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end());
  182. AnfNodePtr out = g->NewCNode(new_node_inputs);
  183. g->set_output(out);
  184. MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(4);
  185. }
  186. MS_LOG(DEBUG) << "End combine graph:" << it.first;
  187. }
  188. return true;
  189. }
  190. bool SymbolResolveAction(const ResourcePtr &res) {
  191. if (res->manager() == nullptr) {
  192. MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
  193. }
  194. if (res->func_graph() == nullptr) {
  195. MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
  196. }
  197. FuncGraphPtr func_graph = res->func_graph();
  198. auto succ = parse::ResolveFuncGraph(func_graph, res);
  199. // Remove unused nodes in cnode order list.
  200. func_graph->EraseUnusedNodeInOrder();
  201. func_graph->ReleaseFullOrderToEffectOrder();
  202. for (auto fg : func_graph->func_graphs_used_total()) {
  203. MS_EXCEPTION_IF_NULL(fg);
  204. fg->EraseUnusedNodeInOrder();
  205. fg->ReleaseFullOrderToEffectOrder();
  206. }
  207. return succ;
  208. }
// Guarded entry point for the inference-optimization preparation pass:
// validates the resource's manager and graph, then delegates to
// InferenceOptPreparePass.
bool InferenceOptPrepareAction(const ResourcePtr &res) {
  if (res->manager() == nullptr) {
    MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
  }
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
  }
  return InferenceOptPreparePass(res);
}
// Infer types/shapes for the top graph and then specialize it.  Parameters
// with default values (hyper parameters) are appended to the argument spec as
// RefKey-tagged abstracts so the analyzer treats them as inputs; the parallel
// parameter context hooks are invoked around that conversion.
bool AbstractSpecializeAction(const ResourcePtr &res) {
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "AbstractSpecialize error";
  }
  FuncGraphPtr func_graph = res->func_graph();
  abstract::AbstractBasePtrList args_spec = res->args_spec();
  parallel::ParallelParameterContextInit(func_graph);
  // suppose that there is not KeywordArgument for the top graph
  // get the hyper parameter
  for (const auto &param : func_graph->parameters()) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    if (param_node->has_default()) {
      auto value = param_node->default_param();
      // NOTE(review): this cast yields nullptr when the default value's
      // abstract is not an AbstractTensor — presumably defaults are always
      // tensors here; confirm before relying on abs_value being non-null.
      auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
      auto ref_key = std::make_shared<RefKey>(param_node->name());
      auto abs_ref_key = ref_key->ToAbstract();
      auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
      args_spec.push_back(abs_ref);
      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
    }
  }
  // Analyze
  AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);
  // The top graph may be replaced by infer, update the top graph when the infer is done
  parse::Parser::UpdateTopFuncGraph(result.context->func_graph());
  // Specialize
  FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context);
  res->set_func_graph(new_fg);
  MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
  return true;
}
// Run every (name, pass) item in `passes` in order, timing each one under the
// profiler.  A pass returning false aborts compilation with an exception.
// When the save-graphs context flag is on, the IR after each pass is dumped
// (numbered by `counter`) for debugging.
bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
  size_t counter = 0;
  for (auto &pass : passes) {
    // WITH(...) scopes the profiler step around the immediately-invoked lambda.
    WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
      MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
      auto result = pass.second(res);
      if (!result) {
        MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
      }
      // Dump the intermediate IR after this pass when requested via context.
      if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && res->func_graph() != nullptr) {
        auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
        auto func_graph = res->func_graph();
        MS_EXCEPTION_IF_NULL(func_graph);
        func_graph->DumpFuncGraph(fg_name);
        DumpIR(fg_name + ".ir", func_graph);
        ExportIR(fg_name + ".dat", "", func_graph);
        MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
      }
      counter++;
      MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
    };
  }
  return true;
}
  274. bool OptInlineAction(const ResourcePtr &res) {
  275. if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) {
  276. return OptimizeAction(res, kInlinePasses);
  277. }
  278. return true;
  279. }
  280. bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
  281. bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
  282. bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }
// Guarded entry point for the pynative elimination pass: validates the
// resource's manager and graph, then delegates to PynativeOptPass.
bool PynativeElimOpt(const ResourcePtr &res) {
  if (res->manager() == nullptr) {
    MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
  }
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "PynativeElimOpt error, graph is null.";
  }
  return PynativeOptPass(res);
}
  292. static bool IsCtrlSink() {
  293. auto ms_ctx = MsContext::GetInstance();
  294. if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {
  295. return false;
  296. }
  297. std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  298. if (device_target != kAscendDevice) {
  299. return false;
  300. }
  301. if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
  302. return false;
  303. }
  304. if (!ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
  305. return false;
  306. }
  307. return true;
  308. }
  309. bool TaskEmitAction(const ResourcePtr &res) {
  310. if (res->func_graph() == nullptr) {
  311. MS_LOG(EXCEPTION) << "TaskEmit args error";
  312. }
  313. FuncGraphPtr func_graph = res->func_graph();
  314. auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
  315. auto context_ptr = MsContext::GetInstance();
  316. std::string backend = MsContext::GetInstance()->backend_policy();
  317. MS_EXCEPTION_IF_NULL(context_ptr);
  318. if (CompileGraphs::ContainMixedTarget(func_graph)) {
  319. bc_ptr->set_is_multi_graph_sink(false);
  320. context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
  321. context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
  322. } else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
  323. std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  324. if (device_target == kAscendDevice && backend != kMsVm) {
  325. bc_ptr->set_is_multi_graph_sink(true);
  326. context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
  327. }
  328. }
  329. if (IsCtrlSink() && backend == kMsConvert) {
  330. res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
  331. return true;
  332. }
  333. std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
  334. if (bc_ptr->name() == kMsConvert) {
  335. cut_list = compile::GetMsNonlinearOps();
  336. }
  337. std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
  338. res->results()[kOutput] = compile->CompileAndLink(func_graph);
  339. return true;
  340. }
  341. bool ExecuteAction(const ResourcePtr &res) {
  342. if (res->results().count(kOutput) == 0) {
  343. MS_LOG(EXCEPTION) << "Execute args error";
  344. }
  345. std::string backend = MsContext::GetInstance()->backend_policy();
  346. if (IsCtrlSink() && backend == kMsConvert) {
  347. if (!res->results()[kOutput].is<GraphId>()) {
  348. MS_LOG(EXCEPTION) << "Execute args error";
  349. }
  350. auto graph_id = res->results()[kOutput].cast<GraphId>();
  351. std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
  352. compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
  353. MS_EXCEPTION_IF_NULL(msbc_ptr);
  354. compile::VmEvalFuncPtr run =
  355. std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
  356. MS_LOG(INFO) << "Execute args size " << args.size();
  357. auto outs = msbc_ptr->RunGraph(graph_id, args);
  358. MS_LOG(DEBUG) << "out size " << outs.size();
  359. return outs[0];
  360. });
  361. res->results()[kOutput] = run;
  362. return true;
  363. }
  364. if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
  365. MS_LOG(EXCEPTION) << "Execute args error";
  366. }
  367. compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
  368. if (vm == nullptr) {
  369. MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";
  370. return true;
  371. }
  372. compile::VmEvalFuncPtr run =
  373. std::make_shared<compile::VmEvalFunc>(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1));
  374. res->results()[kOutput] = run;
  375. return true;
  376. }
  377. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
// Parameter-server role entry points (compiled only for CPU builds paired
// with Ascend or GPU).  Each hands control to the role's Run() routine —
// presumably blocking for the role's lifetime; confirm against ps headers.
bool StartPSWorkerAction(const ResourcePtr &res) {
  ps::worker.Run();
  return true;
}
// The server role additionally needs the trained graph.
bool StartPSServerAction(const ResourcePtr &res) {
  FuncGraphPtr func_graph = res->func_graph();
  auto &ps = ps::ParameterServer<float>::GetInstance();
  ps.Run(func_graph);
  return true;
}
bool StartPSSchedulerAction(const ResourcePtr &res) {
  ps::Scheduler::GetInstance().Run();
  return true;
}
  392. #endif
  393. // The parallel primitive related valuenode might be partitioned so that its value changes by device,
  394. // that will result in a syncronization error due to different executing order.
  395. // Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
  396. // the final solution will be proposed later as a parallel feature.
  397. bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) {
  398. auto &node_users = res->manager()->node_users();
  399. auto &users = node_users[value_node];
  400. auto used_by_keep_value_prim =
  401. std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &user) -> bool {
  402. MS_EXCEPTION_IF_NULL(user.first);
  403. auto cnode = user.first->cast<CNodePtr>();
  404. if (cnode == nullptr) {
  405. return false;
  406. }
  407. auto prim_node = cnode->input(0);
  408. if (IsValueNode<Primitive>(prim_node)) {
  409. auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
  410. // value_node is referenced by some parallel primitive
  411. return prim->HasAttr("keep_value_node_input");
  412. }
  413. return false;
  414. });
  415. return used_by_keep_value_prim;
  416. }
  417. bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
  418. if (res->func_graph() == nullptr) {
  419. MS_LOG(EXCEPTION) << "Remove value node duplications error.";
  420. }
  421. FuncGraphPtr func_graph = res->func_graph();
  422. auto manager = res->manager();
  423. // Remove duplicated value nodes, due to replace operation, can't use reference.
  424. auto value_nodes = func_graph->value_nodes();
  425. HashCache hash_cache;
  426. HashValue hashes;
  427. for (const auto &value_pair : value_nodes) {
  428. if (KeepValueNodeDuplication(value_pair.first, res)) {
  429. continue;
  430. }
  431. TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
  432. }
  433. return true;
  434. }
  435. bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
  436. bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
  437. MS_EXCEPTION_IF_NULL(res->manager());
  438. MS_EXCEPTION_IF_NULL(res->func_graph());
  439. auto ppm = opt::python_pass::PyPassManager::GetInstance();
  440. ppm->SetResource(res);
  441. return ppm->GetPassGroup(phase)->Run(res->func_graph());
  442. }
  443. bool PreAdActionPyStub(const ResourcePtr &res) {
  444. if (!ActionPyStub(res, opt::python_pass::Phase::PREAD)) {
  445. MS_LOG(DEBUG) << "No Match.";
  446. }
  447. return true;
  448. }
  449. bool OptActionVmPyStub(const ResourcePtr &res) {
  450. if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
  451. if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
  452. // Renomalize
  453. MS_EXCEPTION_IF_NULL(res->func_graph());
  454. FuncGraphPtr func_graph = res->func_graph();
  455. abstract::AbstractBasePtrList args_spec;
  456. auto parameters = func_graph->parameters();
  457. (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
  458. [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
  459. FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
  460. res->set_func_graph(new_fg);
  461. res->set_args_spec(args_spec);
  462. }
  463. if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
  464. return VmOptimizeAction(res);
  465. }
  466. }
  467. return true;
  468. }
  469. bool OptActionGePyStub(const ResourcePtr &res) {
  470. if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
  471. if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
  472. // Renomalize
  473. MS_EXCEPTION_IF_NULL(res->func_graph());
  474. FuncGraphPtr func_graph = res->func_graph();
  475. abstract::AbstractBasePtrList args_spec;
  476. auto parameters = func_graph->parameters();
  477. (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
  478. [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
  479. FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
  480. res->set_func_graph(new_fg);
  481. res->set_args_spec(args_spec);
  482. }
  483. if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
  484. return GeOptimizeAction(res);
  485. }
  486. }
  487. return true;
  488. }
  489. static std::vector<ActionItem> CommonPipeline() {
  490. std::vector<ActionItem> actions;
  491. // Parse the python ast to ANF graph
  492. actions.emplace_back(std::make_pair("parse", ParseAction));
  493. // Resolve the python func
  494. actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
  495. auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
  496. if (!multi_graphs) {
  497. actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
  498. }
  499. actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
  500. // Evaluate type and shape, and specialize
  501. actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
  502. // Do data structure simplifications and inline
  503. actions.emplace_back(std::make_pair("inline", OptInlineAction));
  504. // Add pre-ad, post-inline python pass stub
  505. actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));
  506. return actions;
  507. }
  508. std::vector<ActionItem> GePipeline() {
  509. auto actions = CommonPipeline();
  510. // optimize
  511. actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
  512. // Add opt-stage python pass stub
  513. actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
  514. actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
  515. actions.emplace_back(std::make_pair("validate", ValidateAction));
  516. return actions;
  517. }
  518. std::vector<ActionItem> VmPipeline() {
  519. auto actions = CommonPipeline();
  520. // optimize
  521. actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
  522. // Add opt-stage python pass stub
  523. actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
  524. actions.emplace_back(std::make_pair("validate", ValidateAction));
  525. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  526. if (ps::Util::IsRoleOfWorker()) {
  527. actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
  528. }
  529. #endif
  530. // compile the ANF graph
  531. actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
  532. // to execute the graph
  533. actions.emplace_back(std::make_pair("execute", ExecuteAction));
  534. return actions;
  535. }
  536. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  537. std::vector<ActionItem> PServerPipeline() {
  538. auto actions = CommonPipeline();
  539. actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
  540. actions.emplace_back(std::make_pair("validate", ValidateAction));
  541. actions.emplace_back(std::make_pair("pserver", StartPSServerAction));
  542. return actions;
  543. }
  544. std::vector<ActionItem> PSchedulerPipeline() {
  545. std::vector<ActionItem> actions;
  546. actions.emplace_back(std::make_pair("scheduler", StartPSSchedulerAction));
  547. return actions;
  548. }
  549. #endif
  550. } // namespace pipeline
  551. } // namespace mindspore