You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

pipeline.cc 36 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/pipeline.h"
  19. #include <sstream>
  20. #include <map>
  21. #include <unordered_map>
  22. #include <cstdlib>
  23. #include <algorithm>
  24. #include "ir/param_info.h"
  25. #include "pipeline/jit/pass.h"
  26. #include "pipeline/jit/parse/data_converter.h"
  27. #include "frontend/optimizer/ad/dfunctor.h"
  28. #include "debug/anf_ir_dump.h"
  29. #include "debug/dump_proto.h"
  30. #include "debug/anf_ir_utils.h"
  31. #include "utils/config_manager.h"
  32. #include "utils/convert_utils.h"
  33. #include "utils/convert_utils_py.h"
  34. #include "utils/context/context_extends.h"
  35. #include "vm/segment_runner.h"
  36. #include "frontend/parallel/context.h"
  37. #include "frontend/parallel/graph_util/get_parallel_info.h"
  38. #include "runtime/device/kernel_runtime_manager.h"
  39. #include "backend/session/executor_manager.h"
  40. #include "debug/trace.h"
  41. #include "debug/draw.h"
  42. #include "pipeline/pynative/pynative_execute.h"
  43. #include "frontend/optimizer/py_pass_manager.h"
  44. #include "pybind_api/pybind_patch.h"
  45. #include "utils/shape_utils.h"
  46. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  47. #include "ps/common.h"
  48. #include "ps/util.h"
  49. #include "ps/worker.h"
  50. #endif
  51. #if (ENABLE_GE || ENABLE_D)
  52. #include "pipeline/jit/pipeline_ge.h"
  53. #include "transform/graph_ir/convert.h"
  54. #include "transform/graph_ir/df_graph_manager.h"
  55. #include "transform/graph_ir/op_adapter_map.h"
  56. #endif
namespace mindspore {
// namespace to support intermediate representation definition
namespace pipeline {
// Shorthand aliases for the tensor types used throughout this file.
using Tensor = mindspore::tensor::Tensor;
using MetaTensor = mindspore::tensor::MetaTensor;
// Parameter name -> tensor, ordered by name for deterministic iteration.
using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
// IR format selectors accepted by ExecutorPy::GetFuncGraphProto.
const char IR_TYPE_ANF[] = "anf_ir";
const char IR_TYPE_ONNX[] = "onnx_ir";
const char IR_TYPE_MINDIR[] = "mind_ir";
// Process-wide executor singleton and the lock guarding its creation
// (accessor presumably lives outside this chunk — confirm).
ExecutorPyPtr ExecutorPy::executor_ = nullptr;
std::mutex ExecutorPy::instance_lock_;
bool ExecutorPy::debugger_terminate_ = false;
// Cache mapping an argument abstract-signature to a small integer compile key;
// used by GenerateKey so identical signatures reuse the same compiled phase.
std::unordered_map<abstract::AbstractBasePtrList, int, abstract::AbstractBasePtrListHasher,
                   abstract::AbstractBasePtrListEqual>
  g_args_cache;
  76. namespace {
  77. std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) {
  78. std::ostringstream oss;
  79. oss << stage_idx << "_" << action_name;
  80. return oss.str();
  81. }
// Validates a converted graph-mode input: it must be a Tensor and must not be
// a Parameter tensor.
//  arg: converted input value (must be non-null).
//  idx: zero-based argument position, used only in the error message.
// Throws TypeError on violation.
void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) {
  MS_EXCEPTION_IF_NULL(arg);
  auto tensor_arg = arg->cast<TensorPtr>();
  if (tensor_arg == nullptr) {
    MS_EXCEPTION(TypeError) << "For 'graph mode', the " << idx << "th arg: " << arg->ToString() << " is not a tensor.";
  }
  if (tensor_arg->is_parameter()) {
    MS_EXCEPTION(TypeError) << "The inputs could not be Parameter.";
  }
}
  92. } // namespace
// Builds a (name, key) tuple identifying a compile phase: converts each Python
// default argument to an abstract value, then looks the signature up in
// g_args_cache, assigning a fresh integer key to unseen signatures.
// Note: iteration order of the unordered_map `defaults` affects args_spec
// order — presumably callers pass a consistent mapping; verify against caller.
py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
  MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size();
  abstract::AbstractBasePtrList args_spec;
  for (const auto &arg : defaults) {
    // Modules cannot be converted to graph values.
    if (py::isinstance<py::module>(arg.second)) {
      MS_LOG(EXCEPTION) << "GenerateKey failed, argument input should not be py::module";
    }
    ValuePtr converted = nullptr;
    if (!parse::ConvertData(arg.second, &converted)) {
      MS_LOG(EXCEPTION) << "GenerateKey convert arg failed";
    }
    args_spec.push_back(abstract::FromValue(converted, true));
  }
  // First time this signature is seen: allocate the next integer key.
  if (g_args_cache.count(args_spec) == 0) {
    static int key = 0;
    MS_LOG(INFO) << "Start new args and compile key:" << key;
    g_args_cache[args_spec] = key++;
  }
  auto argSpec = py::tuple(2);
  argSpec[0] = name;
  argSpec[1] = g_args_cache[args_spec];
  return argSpec;
}
// Checks that runtime inputs match a declared input signature: same arity,
// and for every Tensor argument an exactly matching shape and dtype against
// the corresponding MetaTensor in the signature. Non-Tensor arguments are
// skipped (only `count` advances). Returns false (with an ERROR log) on the
// first mismatch.
py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs) {
  MS_LOG(DEBUG) << "Verify args size:" << inputs.size();
  if (inputs.size() != input_signature.size()) {
    MS_LOG(ERROR) << "Signature size not equal to args size";
    return false;
  }
  size_t count = 0;
  for (auto arg_obj : inputs) {
    if (py::isinstance<Tensor>(arg_obj)) {
      MS_LOG(DEBUG) << "Verify Tensor";
      auto m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
      if (m_tensor == nullptr) {
        MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
        return false;
      }
      auto sig = input_signature[count].cast<std::shared_ptr<MetaTensor>>();
      ShapeVector sig_shape = sig->shape();
      TypePtr sig_type = sig->Dtype();
      ShapeVector tensor_shape = m_tensor->shape_c();
      // Shapes must match exactly — no broadcasting is applied here.
      if (tensor_shape != sig_shape) {
        MS_LOG(ERROR) << "Python input shape is incompatible with input_signature";
        return false;
      }
      if (*m_tensor->Dtype() != *sig_type) {
        MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString() << ") incompatible with input_signature("
                      << sig_type->ToString() << ")";
        return false;
      }
    }
    count++;
  }
  return true;
}
// Default construction only; per-phase state (info_) is filled lazily by CompileInner.
ExecutorPy::ExecutorPy() {}
  150. ResourcePtr ExecutorPy::GetResource(const std::string &phase) {
  151. MS_LOG(DEBUG) << "Phase size:" << info_.size();
  152. if (info_.count(phase) == 0) {
  153. return nullptr;
  154. }
  155. return info_[phase]->resource;
  156. }
  157. FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) {
  158. if (info_.count(phase) == 0) {
  159. MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase);
  160. }
  161. return info_[phase]->func_graph;
  162. }
// Fetches the VM evaluation functor produced by compilation of `phase`.
// Returns nullptr (with an ERROR log) when the phase's results map has no
// kOutput entry holding a VmEvalFuncPtr.
compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) {
  ResourcePtr res = GetResource(phase);
  MS_EXCEPTION_IF_NULL(res);
  if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is<compile::VmEvalFuncPtr>()) {
    return res->results()[kOutput].cast<compile::VmEvalFuncPtr>();
  }
  MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput;
  return nullptr;
}
  172. bool ExecutorPy::HasCompiled(const std::string &phase) const {
  173. if (info_.count(phase) == 0) {
  174. return false;
  175. }
  176. return true;
  177. }
// Serializes the compiled graph of `phase` into the requested IR format and
// returns the protobuf bytes.
//  ir_type: one of IR_TYPE_ANF ("anf_ir"), IR_TYPE_ONNX ("onnx_ir"),
//           IR_TYPE_MINDIR ("mind_ir").
// Raises when the phase is unknown, the serialization is empty, or ir_type
// is unrecognized.
py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) {
  FuncGraphPtr fg_ptr = GetFuncGraph(phase);
  if (fg_ptr == nullptr) {
    // Dump the known phases to help diagnose the lookup failure.
    for (auto &item : info_) {
      MS_LOG(DEBUG) << "Phase key is: " << item.first;
    }
    MS_LOG(EXCEPTION) << "Can not find func graph " << phase;
  }
  if (ir_type == IR_TYPE_ANF) {
    std::string proto_str = GetFuncGraphProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Graph proto is empty.";
    }
    return proto_str;
  }
  if (ir_type == IR_TYPE_ONNX) {
    std::string proto_str = GetOnnxProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Graph proto is empty.";
    }
    return proto_str;
  }
  if (ir_type == IR_TYPE_MINDIR) {
    std::string proto_str = GetBinaryProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Graph proto is empty.";
    }
    return proto_str;
  }
  MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
}
  209. py::dict ExecutorPy::GetParameterLayout(const std::string &phase) {
  210. MS_LOG(DEBUG) << "GetParameterLayout!";
  211. std::string layout_graph = phase + kStepParallelGraph;
  212. auto graph = GetFuncGraph(layout_graph);
  213. return mindspore::parallel::GetParameterLayout(graph);
  214. }
// Returns the per-CNode strategy dict recorded for `phase`.
// NOTE: operator[] default-inserts an empty dict for an unseen phase — callers
// get {} rather than an error.
py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) {
  MS_LOG(DEBUG) << "GetCNodeStrategy!";
  return stra_dict_[phase];
}
// Records `strategy` for node `name` under the *current* compile phase
// (phase_ is set by CompileInner), keyed as a Python string.
void ExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
  MS_LOG(DEBUG) << "SetCNodeStrategy!";
  stra_dict_[phase_][py::str(name)] = strategy;
}
  223. py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) {
  224. MS_LOG(INFO) << "GetAllreduceFusion!";
  225. auto graph = GetFuncGraph(phase);
  226. return mindspore::parallel::GetAllreduceFusion(graph);
  227. }
// Deletes every compiled-phase entry whose key contains `id` (substring
// match), releasing the associated graphs/resources. Under GE, also tears
// down the GE session when the last entry is removed.
void ExecutorPy::DelNetRes(const std::string &id) {
#ifdef ENABLE_GE
  FinalizeBackend();
#else
  ConfigManager::GetInstance().ResetIterNum();
#endif
  if (executor_ != nullptr) {
    bool flag = false;
    // Iterate over a snapshot so erasing from info_ does not invalidate the
    // loop's iterators.
    auto tmp_info = info_;
    for (auto &item : tmp_info) {
      if (item.first.find(id) != string::npos) {
        MS_LOG(DEBUG) << "Delete network res:" << item.first;
        item.second = nullptr;
        (void)info_.erase(item.first);
        flag = true;
      }
    }
    MS_LOG(DEBUG) << "Delete flag:" << flag;
#ifdef ENABLE_GE
    if (flag && info_.size() == 0) {
      // because Ge only support one Session exist at the same time ,so we delete the old one
      transform::DfGraphManager::GetInstance().DeleteGraphRunner();
      transform::DfGraphManager::GetInstance().EraseAnfGraph();
      transform::DfGraphManager::GetInstance().DeleteGeSession();
    }
#endif
  }
}
  256. void ExecutorPy::ClearRes() {
  257. MS_LOG(INFO) << "Clean executor resource!";
  258. executor_ = nullptr;
  259. }
// Destructor: restores ConfigManager to its default configuration.
ExecutorPy::~ExecutorPy() {
  MS_LOG(INFO) << "Release Executor!";
  ConfigManager::GetInstance().ResetConfig();
}
// Walks the compiled graph of `phase_s` and, for each Conv2D / MatMul /
// DepthwiseConv2dNative node whose weight input comes through a FakeQuant
// node, records weight-name -> (quant primitive, min-parameter name).
// A weight whose activation input chain reaches a Parameter without passing a
// FakeQuant is recorded as (nullptr, "input"). Used by quantization export.
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
  const std::string &phase_s) {
  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
  // Keep only the three quantizable op kinds during the graph search.
  auto filter = [](const AnfNodePtr &node) {
    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
             IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
  };
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
  auto is_quant_cnode = [](const AnfNodePtr &node) {
    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
  };
  for (const auto &node : nodes) {
    // Expect exactly (prim, x, weight); skip anything else.
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() != 3) {
      continue;
    }
    auto x = cnode->input(1);
    auto weight = cnode->input(2);
    if (!is_quant_cnode(weight)) {
      continue;
    }
    // get parameter weight's name
    // NOTE(review): no null/size check after this cast before input(2) —
    // presumably is_quant_cnode guarantees a CNode with >= 3 inputs; confirm.
    cnode = weight->cast<CNodePtr>();
    auto weight_node = cnode->input(2);
    if (!weight_node->isa<Parameter>()) {
      continue;
    }
    auto weight_name = weight_node->cast<ParameterPtr>()->name();
    // find the fakequant from input
    // Follow the activation chain upward through at most max_depth hops.
    int count = 0;
    const int max_depth = 5;
    while (!is_quant_cnode(x)) {
      if (count >= max_depth) {
        break;
      }
      cnode = x->cast<CNodePtr>();
      if (cnode == nullptr || cnode->size() <= 1) {
        break;
      }
      x = cnode->input(1);
      count += 1;
    }
    // Chain ended at a raw Parameter: the input itself is unquantized.
    if (x->isa<Parameter>()) {
      fake_quant_table[weight_name] = std::make_pair(nullptr, "input");
    }
    // get the fakequant parameter minq's name
    if (!is_quant_cnode(x)) {
      continue;
    }
    cnode = x->cast<CNodePtr>();
    // Expect (prim, data, min, max) — four inputs total.
    if (cnode == nullptr || cnode->size() != 4) {
      continue;
    }
    auto fakequant_min_node = cnode->input(2);
    if (!fakequant_min_node->isa<Parameter>()) {
      continue;
    }
    auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
    auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
    if (!quant_op_value->isa<PrimitivePy>()) {
      continue;
    }
    auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
    fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
  }
  return fake_quant_table;
}
// Stores the freshly-compiled func graph in info_[phase_s]. Under (semi-)auto
// parallel mode it additionally stores the step-parallel layout graph under
// the derived phase key "<phase_s><kStepParallelGraph>" so that
// GetParameterLayout can find it later.
void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
  // save the graph to ExecutorPy
  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
  info_[phase_s]->func_graph = func_graph;
  if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) &&
      ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
    func_graph = info_[phase_s]->resource->results()[kStepParallelGraph].cast<FuncGraphPtr>();
    ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
    std::string layout_graph = phase_s + kStepParallelGraph;
    executor_info->func_graph = func_graph;
    info_[layout_graph] = executor_info;
  } else {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!";
  }
  MS_LOG(INFO) << "End save compiled func graph!";
}
  356. void ExecutorPy::GetGeBackendPolicy() const {
  357. auto ms_context = MsContext::GetInstance();
  358. MS_EXCEPTION_IF_NULL(ms_context);
  359. std::string backend = ms_context->backend_policy();
  360. if (backend != "ge") {
  361. MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!";
  362. }
  363. }
  364. bool IsPhaseExportAir(const std::string &phase_s) {
  365. auto phase_to_export = "export.air";
  366. return phase_s.rfind(phase_to_export) != std::string::npos;
  367. }
// Selects the action pipeline for a compile: parameter-server worker/server/
// scheduler pipelines when in PS mode, the VM pipeline for non-GE backends,
// and the GE pipeline otherwise (including AIR export).
// NOTE: "Pipline" spelling is part of the public name — renaming would break
// callers, so it is kept.
std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
  bool is_air = IsPhaseExportAir(phase_s);
  std::string backend = MsContext::GetInstance()->backend_policy();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  if (mindspore::ps::Util::IsParamServerMode()) {
    mindspore::ps::Util::SetInternalEnvVar();
  }
  if (ps::Util::IsRoleOfPServer()) {
    resource->results()[kBackend] = compile::CreateBackend();
    return PServerPipeline();
  }
  if (ps::Util::IsRoleOfScheduler()) {
    return PSchedulerPipeline();
  }
#endif
  if (use_vm && backend != "ge" && !is_air) {
    // Create backend and session
    auto backend_ptr = compile::CreateBackend();
    // Connect session to debugger
    backend_ptr->SetDebugger();
    resource->results()[kBackend] = backend_ptr;
    return VmPipeline();
  }
  return GePipeline();
}
// Compiles the Python callable `obj` for the given `phase`:
// validates arguments, selects an action pipeline, converts the Python args
// into (broadened) abstract specs, runs the pipeline, saves the resulting
// graph, and cleans up optimizer/compile resources.
// Returns false for invalid `phase`/`obj`; raises on conversion or pipeline
// failures.
bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) {
  MS_LOG(DEBUG) << "Start ExecutorPy compile!";
  if ((!py::isinstance<py::str>(phase))) {
    MS_LOG(ERROR) << "Arg phase must be string.";
    return false;
  }
  // check the arg valid?
  if (py::isinstance<py::none>(obj)) {
    MS_LOG(ERROR) << "Find error: parse obj is None.";
    return false;
  }
#ifdef ENABLE_GE
  GetGeBackendPolicy();
#endif
  ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
  auto phase_s = py::cast<std::string>(phase);
  phase_ = phase_s;
  MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!";
  ResourcePtr resource = std::make_shared<Resource>(obj);
  auto p_actions = GetPipline(resource, phase_s, use_vm);
  std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(p_actions, phase_s));
  // get the parameters items and add the value to args_spec
  abstract::AbstractBasePtrList args_spec;
  std::size_t size = args.size();
  for (std::size_t i = 0; i < size; i++) {
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(args[i], &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "Args convert error";
    }
    // Graph mode only accepts plain Tensor inputs.
    if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
      CheckArgIsTensor(converted, i);
    }
    // Broaden each value so the compiled graph generalizes over concrete values.
    bool broaden = true;
    args_spec.push_back(abstract::FromValue(converted, broaden));
  }
  resource->set_args_spec(args_spec);
  executor_info->arg_list_size = size;
  executor_info->resource = resource;
  info_[phase_s] = executor_info;
  pip->Run();
  // save the run graph func to MsPipeLine
  SaveCompiledGraph(phase_s);
  opt::python_pass::PyPassManager::GetInstance()->ClearPipelineRes();
  // Reclaim all resource used by optimizer;
  ReclaimOptimizer();
  resource->Clean();
  MS_LOG(INFO) << "End ExecutorPy compile!";
  return true;
}
  443. std::vector<ActionItem> ExecutorPy::FilterActions(const std::vector<ActionItem> &actions, const std::string &phase) {
  444. // filter action after validate when 'export'.
  445. if (GetPhasePrefix(phase).rfind("export", 0) == std::string::npos) {
  446. return actions;
  447. }
  448. MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'";
  449. std::vector<ActionItem> filtered_actions;
  450. for (const auto &item : actions) {
  451. filtered_actions.emplace_back(item);
  452. if (item.first == "validate") {
  453. break;
  454. }
  455. }
  456. return filtered_actions;
  457. }
  458. void ExecutorPy::ReleaseResource(const py::object &phase) {
  459. ResourcePtr res = GetResource(py::cast<std::string>(phase));
  460. if (res != nullptr) {
  461. res->Clean();
  462. }
  463. // Reclaim all resource used by optimizer;
  464. ReclaimOptimizer();
  465. }
// Debug helper: prints the Python args tuple to stdout via py::print and
// returns an empty string so it can be streamed into MS_LOG.
static std::string PrintArgs(const py::tuple &args) {
  py::print(args);
  return "";
}
// Public compile entry point: wraps CompileInner with exception translation.
// Every catch clause releases the phase's resources before re-throwing, and
// pybind11 exception types are re-thrown as-is so Python sees the original
// exception class. The catch order is significant (most specific first).
bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) {
  bool ret_value = false;
  try {
    MS_LOG(DEBUG) << PrintArgs(args);
    ret_value = CompileInner(obj, args, phase, use_vm);
  } catch (const py::error_already_set &ex) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
    // these info from screen, no need to open log file to find these info
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    ReleaseResource(phase);
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    ReleaseResource(phase);
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    ReleaseResource(phase);
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    ReleaseResource(phase);
    throw py::index_error(ex);
  } catch (const py::key_error &ex) {
    ReleaseResource(phase);
    throw py::key_error(ex);
  } catch (const py::attribute_error &ex) {
    ReleaseResource(phase);
    throw py::attribute_error(ex);
  } catch (const std::exception &ex) {
    ReleaseResource(phase);
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    ReleaseResource(phase);
    // Unknown exception type: report its mangled RTTI name before aborting.
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
  return ret_value;
}
  513. #ifdef ENABLE_LOAD_ANF_IR
  514. // get MindSpore Intermediate Representation File
// get MindSpore Intermediate Representation File
// Reads the MS_IR_FILE environment variable and resolves it to a canonical
// absolute path. Returns "" when the variable is unset or the path cannot be
// resolved. Only compiled under ENABLE_LOAD_ANF_IR.
std::string GetMsIrFile(void) {
  std::string file;
  const char *path = getenv("MS_IR_FILE");
  if (path == nullptr) {
    return file;
  }
  char real_path[PATH_MAX] = {0};
  if (realpath(path, real_path) == nullptr) {
    MS_LOG(ERROR) << "MS IR path error, " << path;
    return file;
  }
  file = real_path;
  return file;
}
// Runs one pipeline action, with special handling when an IR file is supplied
// via MS_IR_FILE: "parse" is skipped, "symbol_resolve" loads the graphs from
// the IR file instead of resolving, and all other actions run normally.
// *result receives the action's success flag (untouched for the two special
// stages, which leave it at the caller's initial value).
void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, bool *result) {
  MS_EXCEPTION_IF_NULL(resource);
  MS_EXCEPTION_IF_NULL(result);
  std::string ir_file = GetMsIrFile();
  (void)parse::python_adapter::set_python_scoped();
  // No IR file configured: run the action as usual.
  if (ir_file.empty()) {
    *result = action.second(resource);
    return;
  }
  // when in loading anf ir mode, action `parse` do nothing
  if (action.first == "parse") {
    return;
  }
  // load MindSpore IR from file
  if (action.first == "symbol_resolve") {
    MS_LOG(DEBUG) << action.first << " read ir file: " << ir_file;
    std::vector<FuncGraphPtr> graphs = ImportIR(ir_file);
    if (graphs.size() == 0) {
      MS_LOG(EXCEPTION) << action.first << " read ir file " << ir_file << " failed as no graph found";
    }
    auto manager = resource->manager();
    MS_EXCEPTION_IF_NULL(manager);
    for (auto &graph : graphs) {
      manager->AddFuncGraph(graph);
    }
    // The first imported graph becomes the entry graph.
    resource->set_func_graph(graphs[0]);
    return;
  }
  // do normal action when not in `parse` and `symbol_resolve` stage
  *result = action.second(resource);
}
  560. #endif
// Executes every action in actions_ in order under the profiler, aborting on
// the first action that returns false. After each action, optionally dumps
// the current graph (dot/ir/dat) when MS_CTX_SAVE_GRAPHS_FLAG is set, and
// prunes/validates node order for graphs with side effects. Finally draws the
// last-seen user graph as ModelDigraph.dot when graph saving is enabled.
void Pipeline::Run() {
  MS_LOG(INFO) << "Pipeline run";
  MS_EXCEPTION_IF_NULL(resource_);
  FuncGraphPtr user_graph = nullptr;
  WITH(MsProfile::GetProfile())[&user_graph, this]() {
    int i = 0;  // stage index, used in dump filenames
    for (auto &action : actions_) {
#ifdef ENABLE_TIMELINE
      DumpTime &dump_time = DumpTime::GetInstance();
      dump_time.Record(action.first, GetTime(), true);
#endif
      bool result = true;
      // Each action runs inside its own profiler step.
      WITH(MsProfile::GetProfile()->Step(action.first))[&result, &action, this]() {
        MS_LOG(DEBUG) << "Action " << action.first << " start ...";
#ifdef ENABLE_LOAD_ANF_IR
        RunPipelineAction(action, resource_, &result);
#else
        result = action.second(resource_);
#endif
        MS_LOG(DEBUG) << "Action " << action.first << " end.";
      };
      if (!result) {
        MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
      }
      if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && resource_->func_graph() != nullptr) {
        auto graph = resource_->func_graph();
        if (graph != nullptr) {
          user_graph = graph;
          std::string base_name = GetBaseNameForIR(i, action.first);
          // generate IR file in dot format, which can be converted to svg file using graphviz dot command
          draw::Draw(base_name + ".dot", graph);
          // generate IR file in human readable format
          DumpIR(base_name + ".ir", graph);
          // generate IR file in a heavily commented format, which can also be reloaded
          ExportIR(base_name + ".dat", std::to_string(i), graph);
        }
#ifdef MS_DEBUG
        // Dump graph cnode list
        MS_LOG(INFO) << "Show CNode list after " << action.first;
        graph->DumpCNodeList();
#endif
      }
      if (resource_->func_graph() != nullptr) {
        auto func_graph = resource_->func_graph();
        // Graphs with side effects need order bookkeeping after every action.
        if (func_graph->has_flag(GRAPH_FLAG_HAS_EFFECT)) {
          func_graph->EraseUnusedNodeInOrder();
          func_graph->CheckOrder();
          for (auto fg : func_graph->func_graphs_used_total()) {
            MS_LOG(DEBUG) << "Check order graph " << fg->ToString() << ".";
            fg->EraseUnusedNodeInOrder();
            fg->CheckOrder();
          }
        }
      }
      i++;
#ifdef ENABLE_TIMELINE
      dump_time.Record(action.first, GetTime(), false);
#endif
    }
  };
#ifdef ENABLE_PROFILE
  MsProfile::Print();
  MsProfile::Reset();
#endif
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && (user_graph != nullptr)) {
    draw::DrawUserFuncGraph("ModelDigraph.dot", user_graph);
  }
  MS_LOG(INFO) << "End";
}
// Converts the Python args tuple into the VM argument list `arg_list`.
// On first use the list is filled by push_back; on later calls the existing
// slots are overwritten in place. Missing trailing arguments are filled from
// the graph parameters' default values (which must be initialized Tensors).
// Raises on conversion failure, numpy-array input under the "ms" backend,
// or an uninitialized/missing default parameter.
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
  MS_EXCEPTION_IF_NULL(arg_list);
  std::size_t size = args.size();
  bool arg_list_inited = !arg_list->empty();
  for (std::size_t i = 0; i < size; i++) {
    py::object arg = args[i];
    auto ms_context = MsContext::GetInstance();
    // Raw numpy arrays are rejected under the "ms" backend; they must be Tensors.
    if (ms_context->backend_policy() == kMsConvert && py::isinstance<py::array>(arg)) {
      MS_LOG(EXCEPTION) << "The " << i << "th arg is numpy array, not tensor.";
    }
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(arg, &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
    }
    if (!arg_list_inited) {
      arg_list->push_back(converted);
      continue;
    }
    if (i >= arg_list->size()) {
      MS_LOG(EXCEPTION) << "i:" << i << " output of range:" << arg_list->size();
    }
    (*arg_list)[i] = converted;
  }
  MS_EXCEPTION_IF_NULL(res);
  auto graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> graph_params = graph->parameters();
  std::size_t graph_params_size = graph_params.size();
  if ((*arg_list).size() != graph_params_size) {
    // maybe some default parameter
    for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
      MS_EXCEPTION_IF_NULL(graph_params[i]);
      // NOTE(review): param_ptr is dereferenced without a null check after the
      // cast — presumably all trailing graph parameters are Parameter nodes;
      // confirm.
      auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
      if (!param_ptr->has_default()) {
        MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
      }
      if (!param_ptr->default_param()->isa<Tensor>()) {
        MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->ToString()
                          << "] is not initialized, need to call `.init_data()`";
      }
      arg_list->push_back(param_ptr->default_param());
    }
  }
}
// Thin wrapper: resolves the Resource for `phase` and delegates to
// ProcessVmArgInner.
void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) {
  ProcessVmArgInner(args, GetResource(phase), arg_list);
}
// If the debugger requested termination (debugger_terminate_ set), clears all
// global resources and exits the process with status 1. Called at the top of
// ExecutorPy::Run so no further step executes.
void ExecutorPy::TerminateDebugger() {
  if (debugger_terminate_) {
    MS_LOG(INFO) << "Terminate debugger and clear resources!";
    ClearResAtexit();
    exit(1);
  }
}
// Executes the compiled graph of `phase` with the given args and returns the
// result as a Python object. GE builds delegate to ExecDFGraph; otherwise a
// trivial graph whose output is a ValueNode/Parameter short-circuits, and the
// general path converts args into the VM argument list and invokes the VM
// evaluation functor produced at compile time.
py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
  // Mindspore debugger notify main thread to exit after one step, and will not run next step
  TerminateDebugger();
  std::size_t size = args.size();
  if (!py::isinstance<py::str>(phase)) {
    MS_LOG(EXCEPTION) << "Run failed, phase input is not a str";
  }
  auto phase_s = py::cast<std::string>(phase);
  std::string backend = MsContext::GetInstance()->backend_policy();
#ifdef ENABLE_GE
  if (backend == "ge") {
    return ExecDFGraph(info_, args, phase_s);
  }
#else
  auto ret_val = std::make_shared<py::object>();
  // Fast path: graph output is already a constant or parameter value.
  if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
    if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
      // Check the input arg must be Tensor when backend is "ms".
      if (MsContext::GetInstance()->backend_policy() == kMsConvert) {
        for (std::size_t i = 0; i < size; i++) {
          ValuePtr converted = nullptr;
          if (!parse::ConvertData(args[i], &converted)) {
            MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
          }
          if (!converted->isa<tensor::Tensor>()) {
            MS_EXCEPTION(TypeError) << "The " << i << "th arg: " << converted->ToString() << " is not tensor.";
          }
        }
      }
      return *ret_val;
    }
  }
  if (backend == "ge") {
    // Virtual output constructed for test cases.
    if (!args.empty()) {
      return args[0];
    }
    return args;
  }
#endif
  auto iter = info_.find(phase_s);
  if (iter == info_.end()) {
    MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase_s);
  }
  auto &execute_info = iter->second;
  MS_EXCEPTION_IF_NULL(execute_info);
  // More args than the compiled signature is tolerated with a warning only.
  if (size > execute_info->arg_list_size) {
    MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << execute_info->arg_list_size;
  }
  ProcessVmArg(args, phase_s, &execute_info->arg_list);
  // Start to run phase.
  compile::VmEvalFuncPtr run = GetVmEvalFunc(phase_s);
  if (run == nullptr) {
    MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s;
  }
  MS_LOG(DEBUG) << "Eval run" << backend;
  BaseRef value = (*run)(execute_info->arg_list);
  MS_LOG(DEBUG) << "Run end";
  return BaseRefToPyData(value);
}
// Build the GE/Davinci dataflow graph for `phase` from the stored executor
// info. Only effective when compiled with GE or D support; otherwise returns
// nullptr so the caller stays on the VM path.
FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase,
                                    const py::object &broadcast_params) {
#if (ENABLE_GE || ENABLE_D)
  return BuildDFGraph(info_, init_params, phase, broadcast_params);
#else
  return nullptr;
#endif
}
  753. void ExecutorPy::UpdataParamNodeDefaultInput(const std::string &phase,
  754. const std::unordered_map<std::string, tensor::TensorPtr> &params_value) {
  755. FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  756. MS_EXCEPTION_IF_NULL(func_graph);
  757. MS_LOG(DEBUG) << "UpdataParamNodeDefaultInput for func graph(" << func_graph->ToString() << ") phase(" << phase
  758. << ")!";
  759. auto &params = func_graph->parameters();
  760. for (const auto &param : params) {
  761. MS_EXCEPTION_IF_NULL(param);
  762. auto param_cast = param->cast<ParameterPtr>();
  763. MS_EXCEPTION_IF_NULL(param_cast);
  764. auto iter = params_value.find(param_cast->name());
  765. if (iter != params_value.end()) {
  766. param_cast->set_default_param(iter->second);
  767. }
  768. }
  769. }
// Run the GE init graph (parameter initialization/broadcast) for `phase`.
// Compiles to a no-op unless built with GE support.
void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) {
#if ENABLE_GE
  RunGEInitGraph(init_params, phase);
#endif
}
  775. bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
  776. const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
  777. const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run) {
  778. std::string name = MsContext::GetInstance()->backend_policy();
  779. #ifndef NO_DLIB
  780. auto ms_context = MsContext::GetInstance();
  781. MS_EXCEPTION_IF_NULL(ms_context);
  782. if (!context::IsTsdOpened(ms_context) || !context::IsGeInited(ms_context)) {
  783. (void)InitBackend();
  784. }
  785. #endif
  786. if (iter_num == -1) {
  787. iter_num = INT32_MAX;
  788. }
  789. if (name == kMsConvert || name == kMsVm) {
  790. return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
  791. }
  792. #if ENABLE_GE
  793. return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase);
  794. #else
  795. std::string backend = MsContext::GetInstance()->backend_policy();
  796. if (backend == "ge") {
  797. return true;
  798. }
  799. #endif
  800. return false;
  801. }
// Build and run a one-off graph containing a single `InitDataSetQueue` node
// to initialize the dataset pipeline for the VM backends.
//   queue_name    - device data queue to bind.
//   size          - iteration count restored into ConfigManager after linking.
//   batch_size    - dataset batch size.
//   types/shapes  - output column types and shapes.
//   input_indexes - column-to-input mapping.
//   need_run      - when false, the init graph is built/linked but not launched.
// Always returns true (failure paths raise via MS_LOG(EXCEPTION)).
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
                       const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                       const std::vector<int64_t> &input_indexes, bool need_run) {
  MS_LOG(INFO) << "Start InitDataSet Entry";
  // NOTE(review): the int64->int casts below narrow; presumably indexes and
  // shape dims always fit in 32 bits — confirm against callers.
  ShapeVector int_input_indexes;
  (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
                       [](int64_t item) { return static_cast<int>(item); });
  std::vector<ShapeVector> int_shapes;
  (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes),
                       [](const std::vector<int64_t> &item) {
                         ShapeVector vector_item;
                         (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item),
                                              [](int64_t inner_item) { return static_cast<int>(inner_item); });
                         return vector_item;
                       });
  // Describe the queue entirely through attributes on the primitive node.
  auto p_init = std::make_shared<Primitive>("InitDataSetQueue");
  p_init->set_attr("queue_name", MakeValue(queue_name));
  p_init->set_attr("size", MakeValue(static_cast<int>(size)));
  p_init->set_attr("batch_size", MakeValue(static_cast<int>(batch_size)));
  p_init->set_attr("types", MakeValue(types));
  p_init->set_attr("shapes", MakeValue(int_shapes));
  p_init->set_attr("input_indexes", MakeValue(int_input_indexes));
  const std::vector<std::string> empty_str_list;
  p_init->set_attr("input_names", MakeValue(empty_str_list));
  p_init->set_attr("output_names", MakeValue(empty_str_list));
  // Wrap the primitive in a minimal single-node func graph whose output is the
  // apply node itself.
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  auto app_init = std::make_shared<CNode>(AnfNodePtrList{NewValueNode(p_init)}, func_graph);
  func_graph->set_output(app_init);
  auto manager = MakeManager();
  manager->AddFuncGraph(func_graph);
  // AbstractNone indicates there is no output for this apply node.
  auto abstract_none = std::make_shared<abstract::AbstractNone>();
  app_init->set_abstract(abstract_none);
  auto backend = compile::CreateBackend();
  MS_EXCEPTION_IF_NULL(backend);
  auto convert_fn = backend->convert_fn();
  MS_EXCEPTION_IF_NULL(convert_fn);
  // Convert CNodeList to LinConvertResult.
  // Temporarily force a single iteration while converting/linking, then
  // restore the real iteration count below.
  ConfigManager::GetInstance().set_iter_num(1);
  auto runner = convert_fn({app_init}, "");
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
    backend->Link(runner.graph_id);
  }
  ConfigManager::GetInstance().set_iter_num(size);
  if (!(*runner.run)) {
    // empty function
    MS_LOG(EXCEPTION) << "Backend " << backend->name() << " unsupported tdt dataset.";
  }
  // launch init dataset runner without inputs and outputs
  VectorRef args;
  auto fn = runner.run;
  if (need_run) {
    (void)(*fn)(args);
  }
  MS_LOG(DEBUG) << "InitDataSetVm End.";
  return true;
}
  859. void ResetOpId() { mindspore::id_generator::reset_id(); }
  860. void InitHccl() {
  861. #ifdef ENABLE_GE
  862. (void)InitBackend();
  863. #else
  864. mindspore::parse::python_adapter::set_python_env_flag(true);
  865. auto ms_context = MsContext::GetInstance();
  866. MS_EXCEPTION_IF_NULL(ms_context);
  867. (void)context::OpenTsd(ms_context);
  868. uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  869. std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  870. ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
  871. if (ms_context->backend_policy() == "ms" &&
  872. ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
  873. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
  874. MS_EXCEPTION_IF_NULL(runtime_instance);
  875. if (!runtime_instance->Init()) {
  876. MS_LOG(ERROR) << "Kernel runtime init error.";
  877. return;
  878. }
  879. }
  880. #endif
  881. }
// Tear down HCCL-related resources: finalize the GE backend on GE builds,
// otherwise release all device kernel runtime resources.
void FinalizeHccl() {
#ifdef ENABLE_GE
  (void)FinalizeBackend();
#else
  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
#endif
}
// Export the compiled graph of `phase` to `file_name` (AIR format). The
// unnamed middle parameter is unused here but kept for interface
// compatibility. Raises ValueError when not built with Ascend (GE/D) support.
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) {
#if (ENABLE_GE || ENABLE_D)
  ExportDFGraph(file_name, phase);
#else
  MS_EXCEPTION(ValueError) << "Only support export file in 'AIR' format with Ascend backend.";
#endif
}
  896. void ReleaseGeTsd() {
  897. auto context_ptr = MsContext::GetInstance();
  898. if (context_ptr != nullptr) {
  899. (void)context::FinalizeGe(context_ptr, true);
  900. (void)context::CloseTsd(context_ptr, true);
  901. }
  902. }
  903. void InitBackend() {
  904. // set python env flag
  905. mindspore::parse::python_adapter::set_python_env_flag(true);
  906. // open tsd before ge initialize
  907. auto ms_context = MsContext::GetInstance();
  908. MS_EXCEPTION_IF_NULL(ms_context);
  909. if (!context::OpenTsd(ms_context)) {
  910. MS_LOG(EXCEPTION) << "Open tsd failed";
  911. }
  912. (void)context::InitGe(ms_context);
  913. }
  914. void FinalizeBackend() {
  915. auto context_ptr = MsContext::GetInstance();
  916. MS_EXCEPTION_IF_NULL(context_ptr);
  917. (void)context::FinalizeGe(context_ptr);
  918. (void)context::CloseTsd(context_ptr);
  919. }
// Release every global pipeline resource at process exit. The sequence below
// is order-sensitive (sessions and caches before executors, device runtime
// before GE/TSD, Python scope last) — do not reorder.
void ClearResAtexit() {
  MS_LOG(DEBUG) << "Pipeline clear all resource";
  pynative::ClearPyNativeSession();
  session::ClearPythonParasMap();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  // In parameter-server mode the worker must be finalized before teardown.
  if (ps::Util::IsParamServerMode()) {
    if (ps::Util::IsRoleOfWorker()) {
      ps::worker.Finalize();
    }
  }
#endif
  ad::g_k_prims.clear();
  abstract::ClearPrimEvaluatorMap();
  compile::ClearConvertCache();
  pipeline::GetMethodMap().clear();
  pipeline::GetAttrMap().clear();
  pipeline::ExecutorPy::ClearRes();
  pipeline::ReclaimOptimizer();
  pynative::PynativeExecutor::GetInstance()->ClearRes();
  opt::python_pass::PyPassManager::GetInstance()->ClearRes();
#ifdef ENABLE_GE
  transform::DfGraphManager::GetInstance().ClearGraph();
  transform::OpAdapterMap::get().clear();
#else
  ConfigManager::GetInstance().ResetIterNum();
#endif
  session::ExecutorManager::Instance().Clear();
  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
  ReleaseGeTsd();
  parse::python_adapter::ResetPythonScope();
}
  951. } // namespace pipeline
  952. } // namespace mindspore