
pynative_execute.cc

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pynative/pynative_execute.h"

#include <typeinfo>
#include <map>
#include <set>
#include <unordered_set>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include <stack>
#include <sstream>
#include <utility>

#include "ir/param_value_py.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "operator/composite/do_signature.h"
#include "pipeline/parse/data_converter.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/parse/resolve.h"
#include "pipeline/static_analysis/prim.h"
#include "session/session_factory.h"
#include "pre_activate/pass/const_input_to_attr_registry.h"
#include "pre_activate/common/helper.h"
#include "pipeline/action.h"
#include "pynative/base.h"
#include "pybind_api/api_register.h"
#include "vm/transform.h"
#include "optimizer/ad/grad.h"
#include "pipeline/resource.h"
#include "pipeline/pipeline.h"
#include "pipeline/pass.h"

#ifdef ENABLE_GE
#include "pynative/pynative_execute_ge.h"
#endif

const char SINGLE_OP_GRAPH[] = "single_op_graph";
// Primitives whose constant-input values cannot be inferred in PyNative mode;
// they are forced onto the VM backend (see RunOp below).
const std::set<std::string> vm_operators = {"make_ref", "HookBackward"};

namespace mindspore {
namespace pynative {
static std::shared_ptr<session::SessionBasic> session = nullptr;
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
ResourcePtr PynativeExecutor::resource_;

inline ValuePtr PyAttrValue(const py::object &obj) {
  ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  if (!converted_ret) {
    MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
  }
  return converted_ret;
}

std::string GetId(const py::object &obj) {
  py::object to_process = obj;
  std::string prefix = "";
  if (py::isinstance<py::tuple>(to_process)) {
    auto p_list = py::cast<py::tuple>(to_process);
    if (p_list.size() == 0) {
      return "empty";
    }
    to_process = p_list[0];
    prefix = "tuple:";
    if (!py::isinstance<tensor::Tensor>(to_process)) {
      std::string key = "";
      for (size_t i = 0; i < p_list.size(); ++i) {
        key += std::string(py::str(p_list[i])) + ":";
      }
      return prefix + key;
    }
  }
  if (py::isinstance<py::int_>(to_process)) {
    return prefix + std::string(py::str(to_process));
  }
  if (py::isinstance<py::float_>(to_process)) {
    return prefix + std::string(py::str(to_process));
  }
  if (py::isinstance<tensor::Tensor>(to_process)) {
    auto tensor_ptr = py::cast<tensor::TensorPtr>(to_process);
    return prefix + tensor_ptr->id();
  }
  py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  return py::cast<std::string>(ret);
}

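// Note on the ID scheme above: scalars map to their printed value (e.g. "2"),
// tensors to their internal id, and tuples of non-tensors to a "tuple:"-prefixed
// key such as "tuple:1:2:"; everything else falls back to the Python-side
// object-id helper. These IDs key the obj_node_map/param_map caches used below.
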
py::object GetTupleObj(const py::object &obj) {
  py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
  return obj_tuple;
}

std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
  std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
  for (size_t i = 0; i < dtypes.size(); ++i) {
    auto it = type_indexes.find(dtypes[i]);
    if (it == type_indexes.end()) {
      (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
    } else {
      it->second.push_back(i);
    }
  }
  return type_indexes;
}

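// Example: for a signature whose dtype placeholders are {T, T, T1}, GetTypeIndex
// returns {T -> [0, 1], T1 -> [2]}, i.e. which argument positions share each
// dtype placeholder and therefore must be cast to a common type.
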
std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
                                                const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
  std::map<SignatureEnumDType, size_t> dst_type;
  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
    auto type = it->first;
    auto indexes = it->second;
    if (indexes.size() < 2) {
      continue;
    }
    size_t m_index = indexes[0];
    for (size_t i = 1; i < indexes.size(); ++i) {
      if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) {
        m_index = indexes[i];
      }
    }
    (void)dst_type.insert(std::make_pair(type, m_index));
  }
  return dst_type;
}

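// For each dtype group with at least two members, GetDstType scans the group and
// keeps the index of the last tensor argument, defaulting to the group's first
// index; ConvertInputs then casts the scalar members of the group to that
// argument's tensor dtype.
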
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args) {
  auto &py_args = *out_args;
  py::tuple input_mask(args.size());
  for (size_t i = 0; i < args.size(); ++i) {
    if (py::hasattr(args[i], "__parameter__")) {
      input_mask[i] = true;
    } else {
      input_mask[i] = false;
    }
    py_args[i] = GetTupleObj(args[i]);
  }
  auto signature = prim->signatures();
  std::vector<SignatureEnumDType> dtypes;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                       [](const Signature &sig) { return sig.dtype; });
  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
    return input_mask;
  }
  auto type_indexes = GetTypeIndex(dtypes);
  auto dst_type = GetDstType(py_args, type_indexes);
  for (size_t i = 0; i < py_args.size(); ++i) {
    auto it = dst_type.find(dtypes[i]);
    if (it != dst_type.end() && it->second != i &&
        (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
      auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
      if (py::isinstance<py::int_>(py_args[i])) {
        py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
      } else {
        py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
      }
      continue;
    }
  }
  return input_mask;
}

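// Example of the implicit conversion performed above: calling a binary op such
// as Mul(tensor_f32, 2) rewrites the Python scalar 2 into a float32 Tensor so
// both members of the dtype group match. The returned mask records which inputs
// carried the __parameter__ attribute, i.e. which are weights.
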
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
  size_t size = py_args.size();
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < size; i++) {
    ValuePtr input_value = PyAttrValue(py_args[i]);
    if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
    } else {
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
    }
  }
  AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  op_exec_info->abstract = infer_res;
}

OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(ERROR) << "Three args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
  MS_EXCEPTION_IF_NULL(op_exec_info);
  op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  auto pyobj = prim->GetPyObj();
  if (pyobj == nullptr) {
    MS_LOG(EXCEPTION) << "pyobj is empty";
  }
  py::list a = args[PY_INPUTS];
  size_t input_num = a.size();
  op_exec_info->op_inputs = py::tuple(input_num);
  op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
  // use python infer method
  if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
    PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
    MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
    return nullptr;
  }
  return op_exec_info;
}

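// The py::args consumed above are positional: args[PY_NAME] is the op name,
// args[PY_PRIM] the PrimitivePy instance, and args[PY_INPUTS] the input list,
// with the count (PY_ARGS_NUM, three) checked up front.
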
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
                                 const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  std::string graph_info;
  // get input tensor info
  size_t input_num = op_exec_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    auto input = op_exec_info->op_inputs[index];
    if (py::isinstance<tensor::Tensor>(input)) {
      auto tensor_ptr = py::cast<tensor::TensorPtr>(input);
      (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
    }
  }
  // get prim and abstract info
  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
                          op_exec_info->abstract->ToString());
  return graph_info;
}

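// The string built here is the cache key for single-op graphs: it concatenates
// each input tensor's shape/dtype info, the primitive's address, and the
// inferred output abstract, so repeated calls with matching inputs reuse the
// graph already built by session->BuildOp in RunOpInMs.
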
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  if (op_exec_info->op_name == "HookBackward") {
    auto op_inputs = op_exec_info->op_inputs;
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      if (py::hasattr(input, "__parameter__")) {
        result[i] = py::getattr(input, "data");
      } else {
        auto tensor = py::cast<tensor::TensorPtr>(op_inputs[i]);
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto func = op_exec_info->py_primitive->GetComputeFunction();
  if (py::isinstance<py::none>(func)) {
    MS_LOG(ERROR) << "VM failed to get func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  // execute op
  py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs));
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "RunOpInVM end";
  return std::move(result);
}

bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                  const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(op_prim);
  auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  if (input_names_value == nullptr) {
    return false;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  if (input_index >= input_names_vec.size()) {
    MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
  }
  if (input_attrs.find(input_index) != input_attrs.end()) {
    ValuePtr value = parse::data_converter::PyDataToValue(input_object);
    MS_EXCEPTION_IF_NULL(value);
    auto input_name = input_names_vec[input_index];
    op_prim->set_attr(input_name, value);
    return true;
  }
  return false;
}

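// For ops registered in ConstInputToAttrInfoRegistry, a constant input is
// folded into a primitive attribute instead of being fed as a tensor. An
// illustrative case (not specific to this file): an axis argument passed as a
// Python int is stored as the attr named by kAttrInputNames at that position.
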
void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
                              std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    input_tensors->push_back(tensor);
  }
  op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
}

void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }
  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->push_back(tensor_ptr);
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                                  std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }
  auto tuple_inputs = py::cast<py::tuple>(input_object);
  if (tuple_inputs.size() == 0) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }
  if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
    PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
  } else {
    ConvertValueTupleToTensor(input_object, input_tensors);
    *tensor_mask = kValueNodeTensorMask;
  }
}

void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                             std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (py::isinstance<tensor::Tensor>(input_object)) {
    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  } else if (py::isinstance<py::float_>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::int_>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::array>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
  } else if (py::isinstance<py::list>(input_object)) {
    auto list_inputs = py::cast<py::list>(input_object);
    py::tuple tuple_inputs(list_inputs.size());
    for (size_t i = 0; i < tuple_inputs.size(); ++i) {
      tuple_inputs[i] = list_inputs[i];
    }
    ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    return;
  } else {
    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->push_back(tensor_ptr);
}

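// Dispatch summary for ConvertPyObjectToTensor: tensors pass through unchanged;
// Python floats/ints become float32/int32 tensors marked kValueNodeTensorMask;
// numpy arrays are wrapped directly; lists and tuples recurse through
// ConvertMultiPyObjectToTensor; None is skipped; anything else is an error.
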
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size "
                      << op_run_info->inputs_mask.size();
  }
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  size_t input_num = op_run_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    // convert const input to attr
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // convert const and tuple input to tensor
    int tensor_mask = py::cast<int>(op_run_info->inputs_mask[index]);
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // mark tensors: data: 0, weight: 1, valuenode: 2
    std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
}

void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  if (input_tensors->size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  std::vector<tensor::TensorPtr> new_input_tensors;
  for (size_t index = 0; index < tensors_mask.size(); ++index) {
    if (tensors_mask[index] != kValueNodeTensorMask) {
      new_input_tensors.push_back(input_tensors->at(index));
    }
  }
  *input_tensors = new_input_tensors;
}

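// Tensors marked kValueNodeTensorMask were baked into the graph as value nodes
// during BuildOp, so they are stripped here before execution: only data and
// weight tensors are fed to session->RunOp at run time.
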
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_enable_pynative_infer(true);
  std::string device_target = ms_context->device_target();
  if (device_target != kAscendDevice && device_target != kGPUDevice) {
    MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in PyNative mode";
  }
  if (session == nullptr) {
    session = session::SessionFactory::Get().Create(device_target);
  }
  MS_EXCEPTION_IF_NULL(session);
  session->Init(ms_context->device_id());

  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int> tensors_mask;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  // get graph info for checking whether it already exists in the cache
  std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
  session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask);
  EraseValueNodeTensor(tensors_mask, &input_tensors);
  py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
  ms_context->set_enable_pynative_infer(false);
  *status = PYNATIVE_SUCCESS;
  return result;
}

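// Flow of a single op on the ms backend: build the input tensors and masks,
// derive the graph-info cache key, let the session build (or fetch) the
// single-op graph, drop value-node tensors, then run and return the result.
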
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info,
                                  PynativeStatusCode *const status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // use vm only
      MS_LOG(INFO) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // use GE first, fall back to vm when GE fails
      MS_LOG(INFO) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // use ms first, fall back to others when ms fails
      MS_LOG(INFO) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}

AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
  if (!grad_flag_ || graph_info_map_.size() == 0) {
    return nullptr;
  }
  std::vector<AnfNodePtr> inputs;
  auto prim = op_exec_info->py_primitive;
  inputs.push_back(NewValueNode(prim));
  py::tuple op_masks = op_exec_info->inputs_mask;
  py::list op_args = args[PY_INPUTS];
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < op_args.size(); i++) {
    auto node = GetInput(op_args[i], op_masks[i]);
    args_spec_list.push_back(node->abstract());
    inputs.push_back(node);
  }
  auto cnode = curr_g_->NewCNode(inputs);
  MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString();
  py::object out_real = out;
  if (out.size() == 1) {
    MS_LOG(DEBUG) << "MakeCnode out size is one.";
    out_real = out[0];
  }
  std::string obj_id = GetId(out_real);
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    if (value.size() > 1) {
      for (int i = 0; i < static_cast<int>(value.size()); i++) {
        auto value_id = GetId(value[i]);
        set_obj_node_map(curr_g_, value_id, cnode, i);
      }
    }
  }
  set_obj_node_map(curr_g_, obj_id, cnode);
  set_pyobj(curr_g_, obj_id);
  return cnode;
}

AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
  auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
  if (out.second == -1) {
    return out.first;
  }
  std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), out.first,
                                                NewValueNode(out.second)};
  return curr_g_->NewCNode(tuple_get_item_inputs);
}

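// obj_node_map stores a (node, index) pair per object id: index -1 means the
// object is the node's whole output, while a non-negative index means it is
// element i of a tuple output, in which case a TupleGetItem node is built to
// extract it.
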
py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  mindspore::parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->PynativeInitGe();
  backend_policy = kMsBackendGeOnly;
#endif
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // returns a null py::tuple on error
  py::tuple err_ret(0);
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
    return err_ret;
  }
  auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
  if (node != nullptr) {
    node->set_abstract(op_exec_info->abstract);
    MS_LOG(DEBUG) << "RunOp MakeCnode, new node is: " << node->DebugString();
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}

py::tuple RunOp(const py::args &args) {
  MS_LOG(DEBUG) << "RunOp start, args size: " << args.size();
  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->abstract != nullptr) {
    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
    if (!output["value"].is_none()) {
      py::tuple value_ret(1);
      value_ret[0] = output["value"];
      return value_ret;
    }
    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
      py::tuple value_ret(1);
      value_ret[0] = "";
      return value_ret;
    }
  }
  return RunOp(op_exec_info, args);
}

void ClearPyNativeSession() { session = nullptr; }

PynativeExecutor::~PynativeExecutor() { ClearRes(); }

PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }

void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  auto cell_id = GetId(cell);
  if (cell_graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "NewGraph already compiled";
    return;
  }
  auto g = std::make_shared<FuncGraph>();
  if (top_g_ == nullptr) {
    top_g_ = curr_g_ = g;
    df_builder_ = std::make_shared<FuncGraph>();
    MS_LOG(DEBUG) << "First new graph " << top_g_.get();
    Pushp();
  } else {
    Pushp();
    curr_g_ = g;
  }
  if (graph_info_map_.count(g) == 0) {
    graph_info_map_[g] = GraphInfo();
  }
  for (size_t i = 0; i < args.size(); i++) {
    auto new_param = g->add_parameter();
    std::string param_obj = GetId(args[i]);
    graph_info_map_[g].param_map[param_obj] = new_param;
  }
}

AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) {
  AnfNodePtr node = nullptr;
  std::string obj_id = GetId(obj);
  if (op_mask != nullptr && py::cast<bool>(op_mask)) {
    MS_LOG(DEBUG) << "Top graph free parameter";
    // get the parameter name from the parameter object
    auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    std::string param_name = py::cast<std::string>(name_attr);
    if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
      auto free_param = df_builder_->add_parameter();
      free_param->set_name(param_name);
      auto free_param_new = std::make_shared<ParamValuePy>(obj);
      free_param->set_default_param(free_param_new);
      free_param->debug_info()->set_name(param_name);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      graph_info_map_[df_builder_].param_map[obj_id] = free_param;
      return free_param;
    }
    return graph_info_map_[df_builder_].param_map[obj_id];
  }
  // if input is a graph parameter or a prior graph output
  if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
    // op(x, y)
    node = graph_info_map_[curr_g_].param_map[obj_id];
  } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj);
  } else if (py::isinstance<py::tuple>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple = obj.cast<py::tuple>();
    auto tuple_size = static_cast<int>(tuple.size());
    for (int i = 0; i < tuple_size; i++) {
      args.push_back(GetInput(tuple[i], py::object()));
    }
    auto cnode = curr_g_->NewCNode(args);
    set_obj_node_map(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    // out = op(x, 1)
    ValuePtr converted_ret = nullptr;
    parse::ConvertData(obj, &converted_ret);
    node = NewValueNode(converted_ret);
    set_obj_node_map(curr_g_, obj_id, node);
  }
  MS_LOG(DEBUG) << "GetInput " << py::str(obj) << " node " << node->ToString();
  return node;
}

void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }

void PynativeExecutor::Popp() {
  if (graph_p_.empty()) {
    MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
  }
  curr_g_ = graph_p_.top();
  graph_p_.pop();
}

void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  auto cell_id = GetId(cell);
  if (cell_graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "EndGraph already compiled";
    return;
  }
  cell_graph_map_[cell_id] = curr_g_;
  auto out_id = GetId(out);
  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
    // the cell construct returns a tuple, e.g. return x, y
    if (py::isinstance<py::tuple>(out)) {
      std::vector<AnfNodePtr> args;
      args.push_back(NewValueNode(prim::kPrimMakeTuple));
      auto tuple = out.cast<py::tuple>();
      MS_LOG(DEBUG) << "End graph start, tuple size: " << tuple.size();
      auto tuple_size = static_cast<int>(tuple.size());
      auto cnode = curr_g_->NewCNode(args);
      for (int i = 0; i < tuple_size; i++) {
        args.push_back(GetInput(tuple[i], py::object()));
        set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
      }
      cnode->set_inputs(args);
      set_obj_node_map(curr_g_, out_id, cnode);
    } else {
      MS_LOG(ERROR) << "Graph has no node for output id: " << out_id;
      return;
    }
  }
  EndGraphByOutId(out_id, cell, out, args);
}

void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out,
                                       const py::args &args) {
  AnfNodePtr output_node;
  if (graph_info_map_[curr_g_].param_map.count(out_id)) {
    output_node = graph_info_map_[curr_g_].param_map[out_id];
  } else {
    output_node = GetObjNode(out);
  }
  curr_g_->set_output(output_node);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(curr_g_));
  MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
  resource_->manager()->AddFuncGraph(curr_g_);
  // custom bprop debug
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    MS_LOG(DEBUG) << "Use cell custom bprop function.";
    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
    if (bprop_graph != nullptr) {
      (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
    }
  }
  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
  if (curr_g_ != top_g_) {
    Popp();
    for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], py::object());
      inputs.push_back(input);
    }
    auto out_cnode = curr_g_->NewCNode(inputs);
    set_pyobj(curr_g_, GetId(cell));
    if (py::isinstance<py::tuple>(out)) {
      auto out_list = py::cast<py::tuple>(out);
      auto out_size = static_cast<int>(out_list.size());
      for (int i = 0; i < out_size; i++) {
        set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
      }
    }
    set_obj_node_map(curr_g_, GetId(out), out_cnode);
  } else {
    parse::ResolveFuncGraph(newfg, resource_);
    resource_->set_func_graph(newfg);
  }
}

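// ad::Grad differentiates the just-closed graph; its third argument tells it
// whether this is the top cell. For a nested cell, the parent graph is popped
// back from the stack and a call node to the sub-graph is spliced in; for the
// top cell, the resolved grad graph becomes the resource's func graph.
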
std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) {
  std::vector<AnfNodePtr> w_args;
  if (py::hasattr(weights, "__parameter_tuple__")) {
    auto tuple = weights.cast<py::tuple>();
    MS_LOG(DEBUG) << "GradNet start weights, tuple size: " << tuple.size();
    w_args.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t it = 0; it < tuple.size(); ++it) {
      auto param = tuple[it];
      auto param_id = GetId(param);
      AnfNodePtr para_node = nullptr;
      if (graph_info_map_[df_builder_].param_map.count(param_id)) {
        para_node = graph_info_map_[df_builder_].param_map[param_id];
        AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
        AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
        auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
        AnfNodePtr ref_key_node = NewValueNode(refkey);
        AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
        w_args.push_back(ref_node);
      }
    }
  } else {
    MS_LOG(EXCEPTION) << "Weights for training should be a ParameterTuple";
  }
  return w_args;
}

abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) {
  abstract::AbstractBasePtrList args_spec;
  std::size_t size = args.size();
  for (std::size_t i = 0; i < size; i++) {
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(args[i], &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "Args convert error";
    }
    bool broaden = true;
    auto abs = abstract::FromValue(converted, broaden);
    args_spec.push_back(abs);
    auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]);
    param_node->set_abstract(abs);
  }
  for (const auto &param : df_builder_->parameters()) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    if (param_node->has_default()) {
      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
      AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
      if (ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Args convert error";
      }
      args_spec.push_back(ptr);
      param_node->set_abstract(ptr);
    }
  }
  return args_spec;
}

void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  MS_LOG(INFO) << "GradNet start, args size: " << args.size();
  std::size_t size = args.size();
  auto cell_id = GetId(cell);
  if (graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "GradNet already compiled";
    return;
  }
  MS_LOG(DEBUG) << "GradNet first compiled";
  std::vector<AnfNodePtr> new_params;
  for (size_t i = 0; i < size; i++) {
    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
    new_params.push_back(p);
  }
  MS_LOG(DEBUG) << "GradNet start, weight size: " << df_builder_->parameters().size();
  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
  df_builder_->set_parameters(new_params);
  resource_->manager()->SetParameters(df_builder_, new_params);

  std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
  MS_EXCEPTION_IF_NULL(resource_->func_graph());
  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
  resource_->set_func_graph(g);

  // get the parameter items and add their values to args_spec
  abstract::AbstractBasePtrList args_spec = GetArgsSpec(args);
  MS_LOG(DEBUG) << "Args_spec size: " << args_spec.size();
  resource_->set_args_spec(args_spec);
  MS_LOG(DEBUG) << "Start opt";

  // create backend and session
  resource_->results()[pipeline::kBackend] = compile::CreateBackend();
  graph_map_[cell_id] = g;
  PynativeOptimizeAction(resource_);
  TaskEmitAction(resource_);
  ExecuteAction(resource_);
  resource_->Clean();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}

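// GradNet compiles the gradient graph once per cell id: fresh parameters are
// prepended for the call arguments, weight refs are gathered, GradGraph wraps
// the forward graph with the GradOperation, and the standard pipeline actions
// (optimize, task emit, execute) finalize the compiled graph before resources
// are cleaned up.
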
void PynativeExecutor::Clear(const std::string &flag) {
  if (flag == "resource") {
    MS_LOG(INFO) << "Clear res";
    Clean();
    return;
  }
  MS_LOG(INFO) << "Clear";
  top_g_ = nullptr;
  curr_g_ = nullptr;
  graph_info_map_.clear();
  std::stack<FuncGraphPtr>().swap(graph_p_);
}

void PynativeExecutor::Clean() {
  MS_LOG(INFO) << "Clean all res";
  Clear();
  grad_flag_ = false;
  df_builder_ = nullptr;
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}

void PynativeExecutor::ClearRes() {
  Clean();
  resource_.reset();
}

py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
  VectorRef arg_list;
  pipeline::ProcessVmArgInner(args, resource_, &arg_list);
  if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
      !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Can't find run graph func";
  }
  compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  if (run == nullptr) {
    MS_LOG(EXCEPTION) << "Can't find run graph func";
  }
  std::string backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  BaseRef value = (*run)(arg_list);
  MS_LOG(DEBUG) << "Run end " << value.ToString();
  return BaseRefToPyData(value);
}

FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op,
                                         const std::vector<AnfNodePtr> &weights, size_t arg_size) {
  auto nparam = top_g_->parameters().size();
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  df_builder_->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  df_builder_->debug_info()->set_name(ss.str());

  auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights);
  std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
  for (size_t i = 0; i < arg_size; ++i) {
    inputs.push_back(df_builder_->parameters()[i]);
  }
  auto out = df_builder_->NewCNode(inputs);
  df_builder_->set_output(out);
  resource_->manager()->AddFuncGraph(df);
  resource_->manager()->AddFuncGraph(df_builder_);
  return df_builder_;
}

REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
                           .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""),
                                "Executor run function.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.");
                       }));

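// Illustrative (not from this file) Python-side usage of the bindings above:
//   executor = PynativeExecutor_.get_instance()
//   executor.set_grad_flag(True)
//   executor.new_graph(cell, *inputs)          # before running the cell
//   executor.end_graph(cell, output, *inputs)  # after the cell produces output
//   executor.grad_net(grad_op, cell, weights, *inputs)
//   result = executor(args, "")                # __call__ -> PynativeExecutor::Run
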
}  // namespace pynative
}  // namespace mindspore