You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

pynative_execute.cc 43 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pynative/pynative_execute.h"
  17. #include <typeinfo>
  18. #include <map>
  19. #include <set>
  20. #include <unordered_set>
  21. #include <algorithm>
  22. #include "debug/trace.h"
  23. #include "ir/tensor_py.h"
  24. #include "ir/param_value_py.h"
  25. #include "utils/any.h"
  26. #include "utils/utils.h"
  27. #include "utils/context/ms_context.h"
  28. #include "operator/ops.h"
  29. #include "operator/composite/composite.h"
  30. #include "operator/composite/do_signature.h"
  31. #include "pipeline/parse/data_converter.h"
  32. #include "pipeline/parse/parse_base.h"
  33. #include "pipeline/parse/resolve.h"
  34. #include "pipeline/static_analysis/prim.h"
  35. #include "session/session_factory.h"
  36. #include "pre_activate/pass/const_input_to_attr_registry.h"
  37. #include "pre_activate/common/helper.h"
  38. #include "pipeline/action.h"
  39. #include "pynative/base.h"
  40. #include "pybind_api/api_register.h"
  41. #include "vm/transform.h"
  42. #include "optimizer/ad/grad.h"
  43. #include "pipeline/resource.h"
  44. #include "pipeline/pipeline.h"
  45. #include "pipeline/pass.h"
  46. #ifdef ENABLE_GE
  47. #include "pynative/pynative_execute_ge.h"
  48. #endif
using mindspore::tensor::TensorPy;
// Identifier tag for graphs that contain a single op.
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
  53. namespace mindspore {
  54. namespace pynative {
// Backend session shared by all single-op executions; created lazily in RunOpInMs
// and reset by ClearPyNativeSession.
static std::shared_ptr<session::SessionBasic> session = nullptr;
// Singleton executor instance plus the lock guarding its creation.
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
ResourcePtr PynativeExecutor::resource_;
// Invoke `method` on `executor`, translating any escaping C++/pybind11 exception:
// executor state is cleaned up via Clean() and the exception is re-thrown so the
// Python interpreter can surface it to the user.
// NOTE(review): the same cleanup/re-throw ladder is duplicated in RunOp below;
// consider sharing it.
template <typename... Args>
void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
  try {
    (executor->*method)(args...);
  } catch (const py::error_already_set &ex) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
    // these info from screen, no need to open log file to find these info
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    PynativeExecutor::GetInstance()->Clean();
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    PynativeExecutor::GetInstance()->Clean();
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    PynativeExecutor::GetInstance()->Clean();
    // unknown exception type: log its RTTI name and abort via MS_LOG(EXCEPTION)
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
}
  94. inline ValuePtr PyAttrValue(const py::object &obj) {
  95. ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  96. if (!converted_ret) {
  97. MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
  98. }
  99. return converted_ret;
  100. }
  101. std::string GetId(const py::object &obj) {
  102. py::object to_process = obj;
  103. std::string prefix = "";
  104. if (py::isinstance<py::tuple>(to_process)) {
  105. auto p_list = py::cast<py::tuple>(to_process);
  106. if (p_list.size() == 0) {
  107. return "empty";
  108. }
  109. prefix = "tuple:";
  110. std::string key = "";
  111. for (size_t i = 0; i < p_list.size(); ++i) {
  112. key += std::string(py::str(GetId(p_list[i]))) + ":";
  113. }
  114. return prefix + key;
  115. }
  116. if (py::isinstance<py::int_>(to_process)) {
  117. return prefix + std::string(py::str(to_process));
  118. }
  119. if (py::isinstance<py::float_>(to_process)) {
  120. return prefix + std::string(py::str(to_process));
  121. }
  122. if (py::isinstance<tensor::Tensor>(to_process)) {
  123. auto tensor_ptr = py::cast<tensor::TensorPtr>(to_process);
  124. return prefix + tensor_ptr->id();
  125. }
  126. py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  127. return py::cast<std::string>(ret);
  128. }
  129. py::object GetTupleObj(const py::object &obj) {
  130. py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  131. py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
  132. return obj_tuple;
  133. }
  134. std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
  135. std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
  136. for (size_t i = 0; i < dtypes.size(); ++i) {
  137. auto it = type_indexes.find(dtypes[i]);
  138. if (it == type_indexes.end()) {
  139. (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
  140. } else {
  141. it->second.push_back(i);
  142. }
  143. }
  144. return type_indexes;
  145. }
  146. std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
  147. const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
  148. std::map<SignatureEnumDType, size_t> dst_type;
  149. for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
  150. auto type = it->first;
  151. auto indexes = it->second;
  152. if (indexes.size() < 2) {
  153. continue;
  154. }
  155. size_t m_index = indexes[0];
  156. for (size_t i = 1; i < indexes.size(); ++i) {
  157. if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) {
  158. m_index = indexes[i];
  159. }
  160. }
  161. (void)dst_type.insert(std::make_pair(type, m_index));
  162. }
  163. return dst_type;
  164. }
  165. py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args,
  166. py::list *const out_args_list) {
  167. auto &py_args = *out_args;
  168. py::tuple input_mask(args.size());
  169. for (size_t i = 0; i < args.size(); ++i) {
  170. if (py::hasattr(args[i], "__parameter__")) {
  171. input_mask[i] = true;
  172. } else {
  173. input_mask[i] = false;
  174. }
  175. py_args[i] = GetTupleObj(args[i]);
  176. }
  177. auto signature = prim->signatures();
  178. std::vector<SignatureEnumDType> dtypes;
  179. (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
  180. [](const Signature &sig) { return sig.dtype; });
  181. int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  182. if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
  183. return input_mask;
  184. }
  185. auto type_indexes = GetTypeIndex(dtypes);
  186. auto dst_type = GetDstType(py_args, type_indexes);
  187. for (size_t i = 0; i < py_args.size(); ++i) {
  188. auto it = dst_type.find(dtypes[i]);
  189. if (it != dst_type.end() && it->second != i &&
  190. (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
  191. auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
  192. if (py::isinstance<py::int_>(py_args[i])) {
  193. py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
  194. (*out_args_list)[i] = py_args[i];
  195. } else {
  196. double arg_value = py::cast<py::float_>(py_args[i]);
  197. py_args[i] = std::make_shared<tensor::Tensor>(arg_value, tensor_ptr->Dtype());
  198. (*out_args_list)[i] = py_args[i];
  199. }
  200. continue;
  201. }
  202. }
  203. return input_mask;
  204. }
  205. void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
  206. size_t size = py_args.size();
  207. AbstractBasePtrList args_spec_list;
  208. for (size_t i = 0; i < size; i++) {
  209. ValuePtr input_value = PyAttrValue(py_args[i]);
  210. if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
  211. args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
  212. } else {
  213. args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
  214. }
  215. }
  216. AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  217. op_exec_info->abstract = infer_res;
  218. }
  219. OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) {
  220. if (args.size() != PY_ARGS_NUM) {
  221. MS_LOG(ERROR) << "Three args are needed by RunOp";
  222. return nullptr;
  223. }
  224. auto op_exec_info = std::make_shared<OpExecInfo>();
  225. MS_EXCEPTION_IF_NULL(op_exec_info);
  226. op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
  227. auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  228. auto pyobj = prim->GetPyObj();
  229. if (pyobj == nullptr) {
  230. MS_LOG(EXCEPTION) << "pyobj is empty";
  231. }
  232. py::list a = args[PY_INPUTS];
  233. size_t input_num = a.size();
  234. op_exec_info->op_inputs = py::tuple(input_num);
  235. op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args);
  236. // use python infer method
  237. if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
  238. PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
  239. }
  240. op_exec_info->py_primitive = prim;
  241. op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  242. if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
  243. MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
  244. return nullptr;
  245. }
  246. return op_exec_info;
  247. }
  248. std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
  249. const std::vector<tensor::TensorPtr> &input_tensors) {
  250. MS_EXCEPTION_IF_NULL(op_exec_info);
  251. std::string graph_info;
  252. // get input tensor info
  253. size_t input_num = op_exec_info->op_inputs.size();
  254. for (size_t index = 0; index < input_num; ++index) {
  255. auto input = op_exec_info->op_inputs[index];
  256. if (py::isinstance<tensor::Tensor>(input)) {
  257. auto tensor_ptr = py::cast<tensor::TensorPtr>(input);
  258. (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
  259. }
  260. }
  261. // get prim and abstract info
  262. MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  263. (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
  264. op_exec_info->abstract->ToString());
  265. return graph_info;
  266. }
// Run an op on the python VM path via the primitive's own compute function.
// HookBackward is special-cased: its inputs are passed straight through, with
// tensors re-wrapped in new Tensor objects constructed from the original's
// data pointer (device address and dirty flag copied over).
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  if (op_exec_info->op_name == "HookBackward") {
    auto op_inputs = op_exec_info->op_inputs;
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      if (py::hasattr(input, "__parameter__")) {
        // Parameters contribute their wrapped data object
        result[i] = py::getattr(input, "data");
      } else {
        auto tensor = py::cast<tensor::TensorPtr>(input);
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
        new_tensor->set_device_address(tensor->device_address());
        new_tensor->set_dirty(tensor->is_dirty());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto func = op_exec_info->py_primitive->GetComputeFunction();
  if (py::isinstance<py::none>(func)) {
    // primitive has no python compute function -> cannot run on VM
    MS_LOG(ERROR) << "VM failed to get func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  // execute op
  py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs));
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "RunOpInVM end";
  return std::move(result);
}
  304. bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
  305. const std::unordered_set<size_t> &input_attrs) {
  306. MS_EXCEPTION_IF_NULL(op_prim);
  307. auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  308. if (input_names_value == nullptr) {
  309. return false;
  310. }
  311. auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  312. if (input_index >= input_names_vec.size()) {
  313. MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!";
  314. }
  315. if (input_attrs.find(input_index) != input_attrs.end()) {
  316. ValuePtr value = parse::data_converter::PyDataToValue(input_object);
  317. MS_EXCEPTION_IF_NULL(value);
  318. auto input_name = input_names_vec[input_index];
  319. op_prim->set_attr(input_name, value);
  320. return true;
  321. }
  322. return false;
  323. }
  324. void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
  325. std::vector<tensor::TensorPtr> *input_tensors) {
  326. MS_EXCEPTION_IF_NULL(op_prim);
  327. MS_EXCEPTION_IF_NULL(input_tensors);
  328. for (const auto &input_object : tuple_inputs) {
  329. if (!py::isinstance<tensor::Tensor>(input_object)) {
  330. MS_LOG(EXCEPTION) << "The input object is not a tensor!";
  331. }
  332. auto tensor = py::cast<tensor::TensorPtr>(input_object);
  333. MS_EXCEPTION_IF_NULL(tensor);
  334. input_tensors->push_back(tensor);
  335. }
  336. op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
  337. }
  338. void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  339. MS_EXCEPTION_IF_NULL(input_tensors);
  340. ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  341. MS_EXCEPTION_IF_NULL(input_value);
  342. if (!input_value->isa<ValueTuple>()) {
  343. MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  344. }
  345. auto value_tuple = input_value->cast<ValueTuplePtr>();
  346. MS_EXCEPTION_IF_NULL(value_tuple);
  347. tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  348. MS_EXCEPTION_IF_NULL(tensor_ptr);
  349. input_tensors->push_back(tensor_ptr);
  350. }
  351. void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
  352. std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  353. MS_EXCEPTION_IF_NULL(op_prim);
  354. MS_EXCEPTION_IF_NULL(input_tensors);
  355. MS_EXCEPTION_IF_NULL(tensor_mask);
  356. if (!py::isinstance<py::tuple>(input_object)) {
  357. MS_LOG(EXCEPTION) << "The input should be a tuple!";
  358. }
  359. auto tuple_inputs = py::cast<py::tuple>(input_object);
  360. if (tuple_inputs.size() == 0) {
  361. MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  362. }
  363. if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
  364. PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
  365. } else {
  366. ConvertValueTupleToTensor(input_object, input_tensors);
  367. *tensor_mask = kValueNodeTensorMask;
  368. }
  369. }
  370. void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
  371. std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  372. MS_EXCEPTION_IF_NULL(op_prim);
  373. MS_EXCEPTION_IF_NULL(input_tensors);
  374. MS_EXCEPTION_IF_NULL(tensor_mask);
  375. tensor::TensorPtr tensor_ptr = nullptr;
  376. if (py::isinstance<tensor::Tensor>(input_object)) {
  377. tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  378. } else if (py::isinstance<py::float_>(input_object)) {
  379. double input_value = py::cast<py::float_>(input_object);
  380. tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
  381. *tensor_mask = kValueNodeTensorMask;
  382. } else if (py::isinstance<py::int_>(input_object)) {
  383. tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
  384. *tensor_mask = kValueNodeTensorMask;
  385. } else if (py::isinstance<py::array>(input_object)) {
  386. tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
  387. } else if (py::isinstance<py::list>(input_object)) {
  388. auto list_inputs = py::cast<py::list>(input_object);
  389. py::tuple tuple_inputs(list_inputs.size());
  390. for (size_t i = 0; i < tuple_inputs.size(); ++i) {
  391. tuple_inputs[i] = list_inputs[i];
  392. }
  393. ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
  394. return;
  395. } else if (py::isinstance<py::tuple>(input_object)) {
  396. ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
  397. return;
  398. } else if (py::isinstance<py::none>(input_object)) {
  399. return;
  400. } else {
  401. MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  402. }
  403. MS_EXCEPTION_IF_NULL(tensor_ptr);
  404. input_tensors->push_back(tensor_ptr);
  405. }
// Turn the op's python inputs into backend tensors plus a parallel mask
// vector. Inputs registered as const-to-attr are folded into the primitive's
// attributes and produce no tensor. Mask values: data 0, weight 1, value node 2.
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size "
                      << op_run_info->inputs_mask.size();
  }
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  size_t input_num = op_run_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    // convert const input to attr
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // convert const and tuple input to tensor
    int tensor_mask = py::cast<int>(op_run_info->inputs_mask[index]);
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // mark tensors, data : 0, weight : 1, valuenode: 2
    // one python input may expand into several tensors (tuple plant), so the
    // mask is replicated once per tensor appended in this iteration
    std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
}
  434. void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
  435. MS_EXCEPTION_IF_NULL(input_tensors);
  436. if (input_tensors->size() != tensors_mask.size()) {
  437. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
  438. << tensors_mask.size();
  439. }
  440. std::vector<tensor::TensorPtr> new_input_tensors;
  441. for (size_t index = 0; index < tensors_mask.size(); ++index) {
  442. if (tensors_mask[index] != kValueNodeTensorMask) {
  443. new_input_tensors.push_back(input_tensors->at(index));
  444. }
  445. }
  446. *input_tensors = new_input_tensors;
  447. }
// Execute a single op through the MindSpore backend session (Ascend/GPU only):
// converts inputs, builds the op graph keyed by its graph-info string (used
// for caching), runs it, and returns the python result tuple.
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_enable_pynative_infer(true);
  std::string device_target = ms_context->device_target();
  if (device_target != kAscendDevice && device_target != kGPUDevice) {
    MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
  }
  // lazily create the file-level shared session on first use
  if (session == nullptr) {
    session = session::SessionFactory::Get().Create(device_target);
  }
  MS_EXCEPTION_IF_NULL(session);
  // NOTE(review): Init runs on every op, even for the cached session — confirm
  // it is idempotent / cheap.
  session->Init(ms_context->device_id());
  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int> tensors_mask;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  // get graph info for checking it whether existing in the cache
  std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
  session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask);
  // value-node tensors were only needed for building; drop them before run
  EraseValueNodeTensor(tensors_mask, &input_tensors);
  py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
  ms_context->set_enable_pynative_infer(false);
  *status = PYNATIVE_SUCCESS;
  return result;
}
// Dispatch a single op to the backend selected by `backend_policy`, writing
// the outcome into *status. The GE branch (with VM fallback) only exists when
// compiled with ENABLE_GE.
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info,
                                  PynativeStatusCode *const status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // use vm only
      MS_LOG(INFO) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // use GE first, use vm when GE fails
      MS_LOG(INFO) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // use Ms first, use others when ms failed
      MS_LOG(INFO) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!!!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}
// Record the just-executed op into the graph currently being traced for grad:
// builds a CNode from the primitive and the nodes of its python args, then
// maps the python output object id(s) to that node. Returns nullptr when grad
// tracing is off or no graph is being built.
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
  if (!grad_flag_ || graph_info_map_.size() == 0) {
    return nullptr;
  }
  std::vector<AnfNodePtr> inputs;
  auto prim = op_exec_info->py_primitive;
  inputs.push_back(NewValueNode(prim));
  py::tuple op_masks = op_exec_info->inputs_mask;
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < args.size(); i++) {
    auto node = GetInput(args[i], op_masks[i]);
    args_spec_list.push_back(node->abstract());
    inputs.push_back(node);
  }
  auto cnode = curr_g_->NewCNode(inputs);
  MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4);
  // a single-element output tuple is unwrapped before id lookup
  py::object out_real = out;
  if (out.size() == 1) {
    MS_LOG(DEBUG) << "MakeCnode out size is one.";
    out_real = out[0];
  }
  std::string obj_id = GetId(out_real);
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    if (value.size() > 1) {
      // map each tuple element id to (cnode, element index) so later ops can
      // fetch individual outputs via TupleGetItem (see GetObjNode)
      for (int i = 0; i < static_cast<int>(value.size()); i++) {
        auto value_id = GetId(value[i]);
        MS_LOG(DEBUG) << "MakeCnode set node id " << value_id;
        set_obj_node_map(curr_g_, value_id, cnode, i);
      }
    }
  }
  MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id;
  set_obj_node_map(curr_g_, obj_id, cnode);
  set_pyobj(curr_g_, obj_id);
  return cnode;
}
// Look up the graph node previously recorded for a python object. When the
// object was an element of a tuple output, wrap the producing node with
// TupleGetItem node(s) following the stored index path.
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
  auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
  // index list {-1} means the object is the node's whole output
  if (out.second.size() == 1 && out.second[0] == -1) {
    return out.first;
  }
  auto node = out.first;
  MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString();
  for (auto &idx : out.second) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
    node = curr_g_->NewCNode(tuple_get_item_inputs);
  }
  MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
  return node;
}
// Execute one op end-to-end: choose a backend policy from the build/config,
// run the op, and (when grad tracing is active) record a CNode for it carrying
// the inferred abstract. Returns an empty tuple on failure.
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  mindspore::parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->PynativeInitGe();
  backend_policy = kMsBackendGeOnly;
#endif
  // ops listed in vm_operators are always forced onto the VM path
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // returns a null py::tuple on error
  py::tuple err_ret(0);
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
    return err_ret;
  }
  auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
  if (node != nullptr) {
    node->set_abstract(op_exec_info->abstract);
    MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}
// Argument-pack overload: build the OpExecInfo, then short-circuit when infer
// already produced a constant output value (or when the primitive carries a
// const_value attr, returning a placeholder); otherwise execute the op.
py::tuple RunOpInner(const py::args &args) {
  MS_LOG(DEBUG) << "RunOp start" << args.size();
  py::list args_input = args[PY_INPUTS];
  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args, &args_input);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->abstract != nullptr) {
    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
    if (!output["value"].is_none()) {
      // infer produced a concrete constant -> no need to run the op
      py::tuple value_ret(1);
      value_ret[0] = output["value"];
      return value_ret;
    }
    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
      // const_value primitive: return an empty-string placeholder
      py::tuple value_ret(1);
      value_ret[0] = "";
      return value_ret;
    }
  }
  return RunOpInner(op_exec_info, args_input);
}
// Python-facing RunOp entry point. Wraps RunOpInner with a cleanup/re-throw
// ladder so any C++ or pybind11 exception clears the executor state before
// propagating back into the Python interpreter.
// NOTE(review): this ladder duplicates PynativeExecutorTry — consider sharing.
py::tuple RunOp(const py::args &args) {
  try {
    return RunOpInner(args);
  } catch (const py::error_already_set &ex) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
    // these info from screen, no need to open log file to find these info
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    PynativeExecutor::GetInstance()->Clean();
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    PynativeExecutor::GetInstance()->Clean();
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    PynativeExecutor::GetInstance()->Clean();
    // unknown exception type: log its RTTI name and abort via MS_LOG(EXCEPTION)
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
}
  653. void ClearPyNativeSession() { session = nullptr; }
// Destructor: release all cached graphs and the pipeline resource.
PynativeExecutor::~PynativeExecutor() { ClearRes(); }
  655. PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
  656. void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
  657. auto cell_id = GetId(cell);
  658. if (cell_graph_map_.count(cell_id) != 0) {
  659. MS_LOG(DEBUG) << "Newgraph already compiled";
  660. return;
  661. }
  662. auto g = std::make_shared<FuncGraph>();
  663. if (top_g_ == nullptr) {
  664. top_g_ = curr_g_ = g;
  665. df_builder_ = std::make_shared<FuncGraph>();
  666. MS_LOG(DEBUG) << "First new graph" << top_g_.get();
  667. Pushp();
  668. } else {
  669. Pushp();
  670. curr_g_ = g;
  671. }
  672. if (graph_info_map_.count(g) == 0) {
  673. graph_info_map_[g] = GraphInfo();
  674. }
  675. for (size_t i = 0; i < args.size(); i++) {
  676. auto new_param = g->add_parameter();
  677. std::string param_obj = GetId(args[i]);
  678. graph_info_map_[g].param_map[param_obj] = new_param;
  679. }
  680. }
  681. AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
  682. ValuePtr converted_ret = nullptr;
  683. parse::ConvertData(obj, &converted_ret);
  684. auto node = NewValueNode(converted_ret);
  685. set_obj_node_map(curr_g_, obj_id, node);
  686. return node;
  687. }
// Resolve a Python input object to its AnfNode in the current graph.
// Lookup order: free parameter (op_mask truthy) -> graph parameter ->
// cached op/cell output -> tuple (recursively packed via MakeTuple) ->
// constant value node.
// op_mask: truthy when `obj` is a weight that must become a free parameter
// of the derivative builder graph (df_builder_).
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) {
  AnfNodePtr node = nullptr;
  std::string obj_id = GetId(obj);
  if (op_mask != nullptr && py::cast<bool>(op_mask)) {
    MS_LOG(DEBUG) << "Topgraph free parameter";
    // get the parameter name from parameter object
    auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    std::string param_name = py::cast<std::string>(name_attr);
    if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
      // First time this weight is seen: register it as a df_builder_ parameter
      // carrying the Python object as its default value.
      auto free_param = df_builder_->add_parameter();
      free_param->set_name(param_name);
      auto free_param_new = std::make_shared<ParamValuePy>(obj);
      free_param->set_default_param(free_param_new);
      free_param->debug_info()->set_name(param_name);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      graph_info_map_[df_builder_].param_map[obj_id] = free_param;
      return free_param;
    }
    return graph_info_map_[df_builder_].param_map[obj_id];
  }
  // if input is graph output
  if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
    // op(x, y)
    node = graph_info_map_[curr_g_].param_map[obj_id];
  } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj);
  } else if (py::isinstance<py::tuple>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    auto tuple = obj.cast<py::tuple>();
    // cell((1,2)): support not mix (scalar, tensor)
    // A tuple whose first element is not a tensor is folded into one constant.
    if (tuple.size() > 0 && !py::isinstance<tensor::Tensor>(tuple[0])) {
      return MakeValueNode(obj, obj_id);
    }
    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple_size = static_cast<int>(tuple.size());
    for (int i = 0; i < tuple_size; i++) {
      args.push_back(GetInput(tuple[i], py::object()));
    }
    auto cnode = curr_g_->NewCNode(args);
    set_obj_node_map(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    node = MakeValueNode(obj, obj_id);
  }
  MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id;
  return node;
}
  743. // for output[0][1] need getitem multi
  744. void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx) {
  745. if (py::isinstance<py::tuple>(obj)) {
  746. auto tuple = obj.cast<py::tuple>();
  747. for (int i = 0; i < static_cast<int>(tuple.size()); i++) {
  748. std::vector<int> tmp = idx;
  749. tmp.push_back(i);
  750. set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, tmp);
  751. SetTupleOutput(tuple[i], cnode, tmp);
  752. }
  753. }
  754. }
// Save the current graph so Popp() can restore it when a nested cell ends.
void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
  756. void PynativeExecutor::Popp() {
  757. if (graph_p_.empty()) {
  758. MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
  759. }
  760. curr_g_ = graph_p_.top();
  761. graph_p_.pop();
  762. }
// Finish tracing a cell: record its graph and make sure the cell's output
// object has a node in curr_g_, synthesizing a MakeTuple when the cell's
// construct returned a bare Python tuple.
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
  auto cell_id = GetId(cell);
  if (cell_graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "Endgraph already compiled";
    return;
  }
  cell_graph_map_[cell_id] = curr_g_;
  auto out_id = GetId(out);
  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
    // cell construct return x, y
    if (py::isinstance<py::tuple>(out)) {
      std::vector<AnfNodePtr> args;
      args.push_back(NewValueNode(prim::kPrimMakeTuple));
      auto tuple = out.cast<py::tuple>();
      MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size();
      auto tuple_size = static_cast<int>(tuple.size());
      // NOTE: the cnode is created first so each element can be mapped to
      // (cnode, index) inside the loop; its real inputs are attached after.
      auto cnode = curr_g_->NewCNode(args);
      for (int i = 0; i < tuple_size; i++) {
        args.push_back(GetInput(tuple[i], py::object()));
        set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
        SetTupleOutput(tuple[i], cnode, std::vector<int>{i});
      }
      cnode->set_inputs(args);
      set_obj_node_map(curr_g_, out_id, cnode);
    } else {
      MS_LOG(ERROR) << "Graph has no this out: " << out_id;
      return;
    }
  }
  EndGraphByOutId(out_id, cell, out, args);
}
// Close the current graph with the node for `out_id`, attach a custom bprop
// when the cell defines one, take the grad of the graph, and then either
// splice a call into the parent graph (nested cell) or install the grad
// graph as the top-level func graph.
void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out,
                                       const py::args &args) {
  AnfNodePtr output_node;
  if (graph_info_map_[curr_g_].param_map.count(out_id)) {
    output_node = graph_info_map_[curr_g_].param_map[out_id];
  } else {
    output_node = GetObjNode(out);
  }
  curr_g_->set_output(output_node);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(curr_g_));
  MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
  resource_->manager()->AddFuncGraph(curr_g_);
  // custom bprop debug
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    MS_LOG(DEBUG) << "Use cell custom bprop function.";
    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
    if (bprop_graph != nullptr) {
      // Link the forward graph and the custom bprop graph in both directions.
      (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
    }
  }
  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
  if (curr_g_ != top_g_) {
    // Nested cell: restore the parent graph, then emit a call to this cell's
    // graph with the original Python arguments as inputs.
    Popp();
    for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], py::object());
      inputs.push_back(input);
    }
    auto out_cnode = curr_g_->NewCNode(inputs);
    set_pyobj(curr_g_, GetId(cell));
    if (py::isinstance<py::tuple>(out)) {
      // Map each tuple element (recursively) of the nested cell's output.
      auto out_list = py::cast<py::tuple>(out);
      auto out_size = static_cast<int>(out_list.size());
      for (int i = 0; i < out_size; i++) {
        set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
        SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i});
      }
    }
    set_obj_node_map(curr_g_, GetId(out), out_cnode);
  } else {
    // Top cell: resolve symbols in the grad graph and make it the entry graph.
    parse::ResolveFuncGraph(newfg, resource_);
    resource_->set_func_graph(newfg);
  }
}
  839. std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) {
  840. std::vector<AnfNodePtr> w_args;
  841. if (py::hasattr(weights, "__parameter_tuple__")) {
  842. auto tuple = weights.cast<py::tuple>();
  843. MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size();
  844. w_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  845. for (size_t it = 0; it < tuple.size(); ++it) {
  846. auto param = tuple[it];
  847. auto param_id = GetId(param);
  848. AnfNodePtr para_node = nullptr;
  849. if (graph_info_map_[df_builder_].param_map.count(param_id)) {
  850. para_node = graph_info_map_[df_builder_].param_map[param_id];
  851. AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
  852. AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
  853. auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
  854. AnfNodePtr ref_key_node = NewValueNode(refkey);
  855. AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
  856. w_args.push_back(ref_node);
  857. }
  858. }
  859. } else {
  860. MS_LOG(EXCEPTION) << "training not paramter_tuple";
  861. }
  862. return w_args;
  863. }
  864. abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) {
  865. abstract::AbstractBasePtrList args_spec;
  866. std::size_t size = args.size();
  867. for (std::size_t i = 0; i < size; i++) {
  868. ValuePtr converted = nullptr;
  869. bool succ = parse::ConvertData(args[i], &converted);
  870. if (!succ) {
  871. MS_LOG(EXCEPTION) << "Args convert error";
  872. }
  873. bool broaden = true;
  874. auto abs = abstract::FromValue(converted, broaden);
  875. args_spec.push_back(abs);
  876. auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]);
  877. param_node->set_abstract(abs);
  878. }
  879. for (const auto &param : df_builder_->parameters()) {
  880. auto param_node = std::static_pointer_cast<Parameter>(param);
  881. if (param_node->has_default()) {
  882. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
  883. AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
  884. if (ptr == nullptr) {
  885. MS_LOG(EXCEPTION) << "Args convert error";
  886. }
  887. args_spec.push_back(ptr);
  888. param_node->set_abstract(ptr);
  889. }
  890. }
  891. return args_spec;
  892. }
// Build and compile the backward graph for `cell`:
// prepend fresh input parameters to df_builder_, wrap the traced forward
// graph with the grad operation, infer argument specs, then run the
// pynative optimize / task-emit / execute pipeline actions.
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                    const py::args &args) {
  MS_LOG(INFO) << "GradNet start" << args.size();
  std::size_t size = args.size();
  auto cell_id = GetId(cell);
  if (graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "GradNet already compiled";
    return;
  }
  MS_LOG(DEBUG) << "GradNet first compiled";
  // One fresh parameter per network input; the existing df_builder_
  // parameters (weights registered by GetInput) are appended after them.
  std::vector<AnfNodePtr> new_params;
  for (size_t i = 0; i < size; i++) {
    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
    new_params.push_back(p);
  }
  MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
  df_builder_->set_parameters(new_params);
  resource_->manager()->SetParameters(df_builder_, new_params);
  std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
  MS_EXCEPTION_IF_NULL(resource_->func_graph());
  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
  resource_->set_func_graph(g);
  resource_->manager()->KeepRoots({g});
  // get the parameters items and add the value to args_spec
  abstract::AbstractBasePtrList args_spec = GetArgsSpec(args);
  MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
  resource_->set_args_spec(args_spec);
  MS_LOG(DEBUG) << "Start opt";
  // Create backend and session
  resource_->results()[pipeline::kBackend] = compile::CreateBackend();
  graph_map_[cell_id] = g;
  PynativeOptimizeAction(resource_);
  TaskEmitAction(resource_);
  ExecuteAction(resource_);
  // Release compile-time resources; the runnable result stays in results().
  resource_->Clean();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
  932. void PynativeExecutor::Clear(const std::string &flag) {
  933. if (!flag.empty()) {
  934. MS_LOG(INFO) << "Clear res";
  935. (void)graph_map_.erase(flag);
  936. (void)cell_graph_map_.erase(flag);
  937. Clean();
  938. // Maybe exit in the pynative runing op, so need reset pynative flag.
  939. auto ms_context = MsContext::GetInstance();
  940. if (ms_context != nullptr) {
  941. ms_context->set_enable_pynative_infer(false);
  942. }
  943. return;
  944. }
  945. MS_LOG(INFO) << "Clear";
  946. top_g_ = nullptr;
  947. curr_g_ = nullptr;
  948. graph_info_map_.clear();
  949. std::stack<FuncGraphPtr>().swap(graph_p_);
  950. }
// Full teardown of tracing state: graphs (via Clear with empty flag),
// grad flag, derivative builder, autodiff caches and the optimizer.
void PynativeExecutor::Clean() {
  MS_LOG(INFO) << "Clean all res";
  Clear();
  grad_flag_ = false;
  df_builder_ = nullptr;
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
// Used by the destructor: clean tracing state, then drop the pipeline
// resource itself.
void PynativeExecutor::ClearRes() {
  Clean();
  resource_.reset();
}
  963. py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
  964. VectorRef arg_list;
  965. pipeline::ProcessVmArgInner(args, resource_, &arg_list);
  966. if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
  967. !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
  968. MS_LOG(EXCEPTION) << "Can't find run graph func for ";
  969. }
  970. compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  971. if (run == nullptr) {
  972. MS_LOG(EXCEPTION) << "Can't find run graph func for ";
  973. }
  974. std::string backend = MsContext::GetInstance()->backend_policy();
  975. MS_LOG(DEBUG) << "Eval run" << backend;
  976. BaseRef value = (*run)(arg_list);
  977. MS_LOG(DEBUG) << "Run end" << value.ToString();
  978. return BaseRefToPyData(value);
  979. }
  980. FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op,
  981. const std::vector<AnfNodePtr> &weights, size_t arg_size) {
  982. auto nparam = top_g_->parameters().size();
  983. std::ostringstream ss;
  984. ss << "grad{" << nparam << "}";
  985. df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  986. df_builder_->debug_info()->set_name(ss.str());
  987. auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights);
  988. std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
  989. for (size_t i = 0; i < arg_size; ++i) {
  990. inputs.push_back(df_builder_->parameters()[i]);
  991. }
  992. auto out = df_builder_->NewCNode(inputs);
  993. df_builder_->set_output(out);
  994. resource_->manager()->AddFuncGraph(df);
  995. resource_->manager()->AddFuncGraph(df_builder_);
  996. return df_builder_;
  997. }
// Python entry: start tracing `cell`; wraps NewGraphInner with the common
// exception-cleanup handler.
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args);
}
// Python entry: finish tracing `cell` with output `out`; wraps EndGraphInner
// with the common exception-cleanup handler.
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
}
// Python entry: build/compile the grad graph for `cell`; wraps GradNetInner
// with the common exception-cleanup handler.
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
}
// Expose PynativeExecutor to Python as class `PynativeExecutor_`.
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
                           .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""),
                                "Executor run function.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.");
                       }));
  1020. } // namespace pynative
  1021. } // namespace mindspore