You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; can include dashes ('-'); and can be up to 35 characters long.

pynative_execute.cc 39 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pynative/pynative_execute.h"
  17. #include <typeinfo>
  18. #include <map>
  19. #include <set>
  20. #include <unordered_set>
  21. #include <algorithm>
  22. #include "ir/param_value_py.h"
  23. #include "utils/any.h"
  24. #include "utils/utils.h"
  25. #include "utils/context/ms_context.h"
  26. #include "operator/ops.h"
  27. #include "operator/composite/composite.h"
  28. #include "operator/composite/do_signature.h"
  29. #include "pipeline/parse/data_converter.h"
  30. #include "pipeline/parse/parse_base.h"
  31. #include "pipeline/parse/resolve.h"
  32. #include "pipeline/static_analysis/prim.h"
  33. #include "session/session_factory.h"
  34. #include "pre_activate/pass/const_input_to_attr_registry.h"
  35. #include "pre_activate/common/helper.h"
  36. #include "pipeline/action.h"
  37. #include "pynative/base.h"
  38. #include "pybind_api/api_register.h"
  39. #include "vm/transform.h"
  40. #include "optimizer/ad/grad.h"
  41. #include "pipeline/resource.h"
  42. #include "pipeline/pipeline.h"
  43. #include "pipeline/pass.h"
  44. #ifdef ENABLE_GE
  45. #include "pynative/pynative_execute_ge.h"
  46. #endif
// Graph-info key prefix used for single-op graph construction in the session.
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};

namespace mindspore {
namespace pynative {
// Lazily created backend session shared by all single-op executions in this file.
static std::shared_ptr<session::SessionBasic> session = nullptr;
// Definitions of PynativeExecutor's singleton statics (declared in the header).
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
ResourcePtr PynativeExecutor::resource_;
  56. inline ValuePtr PyAttrValue(const py::object &obj) {
  57. ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  58. if (!converted_ret) {
  59. MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
  60. }
  61. return converted_ret;
  62. }
  63. std::string GetId(const py::object &obj) {
  64. py::object to_process = obj;
  65. std::string prefix = "";
  66. if (py::isinstance<py::tuple>(to_process)) {
  67. auto p_list = py::cast<py::tuple>(to_process);
  68. if (p_list.size() == 0) {
  69. return "empty";
  70. }
  71. prefix = "tuple:";
  72. std::string key = "";
  73. for (size_t i = 0; i < p_list.size(); ++i) {
  74. key += std::string(py::str(GetId(p_list[i]))) + ":";
  75. }
  76. return prefix + key;
  77. }
  78. if (py::isinstance<py::int_>(to_process)) {
  79. return prefix + std::string(py::str(to_process));
  80. }
  81. if (py::isinstance<py::float_>(to_process)) {
  82. return prefix + std::string(py::str(to_process));
  83. }
  84. if (py::isinstance<tensor::Tensor>(to_process)) {
  85. auto tensor_ptr = py::cast<tensor::TensorPtr>(to_process);
  86. return prefix + tensor_ptr->id();
  87. }
  88. py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  89. return py::cast<std::string>(ret);
  90. }
  91. py::object GetTupleObj(const py::object &obj) {
  92. py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  93. py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
  94. return obj_tuple;
  95. }
  96. std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
  97. std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
  98. for (size_t i = 0; i < dtypes.size(); ++i) {
  99. auto it = type_indexes.find(dtypes[i]);
  100. if (it == type_indexes.end()) {
  101. (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
  102. } else {
  103. it->second.push_back(i);
  104. }
  105. }
  106. return type_indexes;
  107. }
  108. std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
  109. const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
  110. std::map<SignatureEnumDType, size_t> dst_type;
  111. for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
  112. auto type = it->first;
  113. auto indexes = it->second;
  114. if (indexes.size() < 2) {
  115. continue;
  116. }
  117. size_t m_index = indexes[0];
  118. for (size_t i = 1; i < indexes.size(); ++i) {
  119. if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) {
  120. m_index = indexes[i];
  121. }
  122. }
  123. (void)dst_type.insert(std::make_pair(type, m_index));
  124. }
  125. return dst_type;
  126. }
// Build the per-input mask (Parameter vs. plain value) for an op and perform
// implicit dtype promotion: scalar int/float arguments that share a dtype
// signature group with a tensor argument are converted into tensors carrying
// that tensor's dtype. `out_args` receives the (possibly converted) inputs and
// `out_args_list` mirrors the conversions back into the caller-visible list.
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args,
                        py::list *out_args_list) {
  auto &py_args = *out_args;
  py::tuple input_mask(args.size());
  for (size_t i = 0; i < args.size(); ++i) {
    // Parameter objects (weights) are flagged true in the mask.
    if (py::hasattr(args[i], "__parameter__")) {
      input_mask[i] = true;
    } else {
      input_mask[i] = false;
    }
    py_args[i] = GetTupleObj(args[i]);
  }
  auto signature = prim->signatures();
  std::vector<SignatureEnumDType> dtypes;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                       [](const Signature &sig) { return sig.dtype; });
  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
    // No dtype constraints declared: nothing to promote.
    return input_mask;
  }
  auto type_indexes = GetTypeIndex(dtypes);
  auto dst_type = GetDstType(py_args, type_indexes);
  // NOTE(review): dtypes and py_args are indexed by the same i — this assumes the
  // signature list is at least as long as the argument list; verify with callers.
  for (size_t i = 0; i < py_args.size(); ++i) {
    auto it = dst_type.find(dtypes[i]);
    if (it != dst_type.end() && it->second != i &&
        (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
      // Promote the scalar to a tensor with the dtype of the group's reference tensor.
      auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
      if (py::isinstance<py::int_>(py_args[i])) {
        py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
        (*out_args_list)[i] = py_args[i];
      } else {
        py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
        (*out_args_list)[i] = py_args[i];
      }
      continue;
    }
  }
  return input_mask;
}
  166. void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
  167. size_t size = py_args.size();
  168. AbstractBasePtrList args_spec_list;
  169. for (size_t i = 0; i < size; i++) {
  170. ValuePtr input_value = PyAttrValue(py_args[i]);
  171. if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
  172. args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
  173. } else {
  174. args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
  175. }
  176. }
  177. AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  178. op_exec_info->abstract = infer_res;
  179. }
// Parse the raw python RunOp arguments (name, primitive, inputs) into an
// OpExecInfo: fills op name, primitive, converted inputs, input mask, attrs and
// (for most ops) the inferred abstract. Returns nullptr on malformed input
// (wrong arg count or input/mask size mismatch).
OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(ERROR) << "Three args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
  MS_EXCEPTION_IF_NULL(op_exec_info);
  op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  auto pyobj = prim->GetPyObj();
  if (pyobj == nullptr) {
    MS_LOG(EXCEPTION) << "pyobj is empty";
  }
  py::list a = args[PY_INPUTS];
  size_t input_num = a.size();
  op_exec_info->op_inputs = py::tuple(input_num);
  // Fills op_inputs/inputs_mask and applies implicit scalar-to-tensor promotion.
  op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args);
  // use python infer method
  if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
    PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
    MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
    return nullptr;
  }
  return op_exec_info;
}
// Build a cache key for a single-op graph from the shape/dtype info of the op's
// tensor inputs, the primitive's identity (its pointer value) and the inferred
// abstract. NOTE(review): the `input_tensors` parameter is unused — the key is
// derived from op_exec_info->op_inputs instead; confirm whether that is intended.
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
                                 const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  std::string graph_info;
  // get input tensor info
  size_t input_num = op_exec_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    auto input = op_exec_info->op_inputs[index];
    if (py::isinstance<tensor::Tensor>(input)) {
      auto tensor_ptr = py::cast<tensor::TensorPtr>(input);
      (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
    }
  }
  // get prim and abstract info
  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
                          op_exec_info->abstract->ToString());
  return graph_info;
}
// Execute the op through the primitive's python VM compute function.
// HookBackward is special-cased as an identity that repackages its inputs.
// On missing compute function, sets PYNATIVE_OP_NOT_IMPLEMENTED_ERR and returns
// an empty tuple.
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  if (op_exec_info->op_name == "HookBackward") {
    auto op_inputs = op_exec_info->op_inputs;
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      if (py::hasattr(input, "__parameter__")) {
        // Parameter: pass through its underlying data attribute.
        result[i] = py::getattr(input, "data");
      } else {
        // Plain tensor: wrap its data in a fresh Tensor object.
        auto tensor = py::cast<tensor::TensorPtr>(op_inputs[i]);
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto func = op_exec_info->py_primitive->GetComputeFunction();
  if (py::isinstance<py::none>(func)) {
    // No python compute function registered for this primitive.
    MS_LOG(ERROR) << "VM failed to get func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  // execute op
  py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs));
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "RunOpInVM end";
  return std::move(result);
}
  263. bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
  264. const std::unordered_set<size_t> &input_attrs) {
  265. MS_EXCEPTION_IF_NULL(op_prim);
  266. auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  267. if (input_names_value == nullptr) {
  268. return false;
  269. }
  270. auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  271. if (input_index >= input_names_vec.size()) {
  272. MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!";
  273. }
  274. if (input_attrs.find(input_index) != input_attrs.end()) {
  275. ValuePtr value = parse::data_converter::PyDataToValue(input_object);
  276. MS_EXCEPTION_IF_NULL(value);
  277. auto input_name = input_names_vec[input_index];
  278. op_prim->set_attr(input_name, value);
  279. return true;
  280. }
  281. return false;
  282. }
  283. void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
  284. std::vector<tensor::TensorPtr> *input_tensors) {
  285. MS_EXCEPTION_IF_NULL(op_prim);
  286. MS_EXCEPTION_IF_NULL(input_tensors);
  287. for (const auto &input_object : tuple_inputs) {
  288. if (!py::isinstance<tensor::Tensor>(input_object)) {
  289. MS_LOG(EXCEPTION) << "The input object is not a tensor!";
  290. }
  291. auto tensor = py::cast<tensor::TensorPtr>(input_object);
  292. MS_EXCEPTION_IF_NULL(tensor);
  293. input_tensors->push_back(tensor);
  294. }
  295. op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
  296. }
  297. void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  298. MS_EXCEPTION_IF_NULL(input_tensors);
  299. ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  300. MS_EXCEPTION_IF_NULL(input_value);
  301. if (!input_value->isa<ValueTuple>()) {
  302. MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  303. }
  304. auto value_tuple = input_value->cast<ValueTuplePtr>();
  305. MS_EXCEPTION_IF_NULL(value_tuple);
  306. tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  307. MS_EXCEPTION_IF_NULL(tensor_ptr);
  308. input_tensors->push_back(tensor_ptr);
  309. }
  310. void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
  311. std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  312. MS_EXCEPTION_IF_NULL(op_prim);
  313. MS_EXCEPTION_IF_NULL(input_tensors);
  314. MS_EXCEPTION_IF_NULL(tensor_mask);
  315. if (!py::isinstance<py::tuple>(input_object)) {
  316. MS_LOG(EXCEPTION) << "The input should be a tuple!";
  317. }
  318. auto tuple_inputs = py::cast<py::tuple>(input_object);
  319. if (tuple_inputs.size() == 0) {
  320. MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  321. }
  322. if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
  323. PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
  324. } else {
  325. ConvertValueTupleToTensor(input_object, input_tensors);
  326. *tensor_mask = kValueNodeTensorMask;
  327. }
  328. }
// Convert a single op input (tensor, scalar, numpy array, list/tuple or None)
// into tensor(s) appended to `input_tensors`. Scalars and value tuples are
// flagged through `tensor_mask` as value-node tensors; None inputs are skipped.
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                             std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (py::isinstance<tensor::Tensor>(input_object)) {
    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  } else if (py::isinstance<py::float_>(input_object)) {
    // Python float scalars become float32 value-node tensors.
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::int_>(input_object)) {
    // Python int scalars become int32 value-node tensors.
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::array>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
  } else if (py::isinstance<py::list>(input_object)) {
    // Normalize a list to a tuple, then handle it as a multi-object input.
    auto list_inputs = py::cast<py::list>(input_object);
    py::tuple tuple_inputs(list_inputs.size());
    for (size_t i = 0; i < tuple_inputs.size(); ++i) {
      tuple_inputs[i] = list_inputs[i];
    }
    ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    // None inputs produce no tensor.
    return;
  } else {
    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->push_back(tensor_ptr);
}
// Lower all python op inputs to backend tensors. Inputs registered for
// const-to-attr conversion are folded into primitive attributes; everything
// else goes through ConvertPyObjectToTensor. `tensors_mask` receives one entry
// per produced tensor (data: 0, weight: 1, value node: 2).
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size "
                      << op_run_info->inputs_mask.size();
  }
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  size_t input_num = op_run_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    // convert const input to attr
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // convert const and tuple input to tensor
    int tensor_mask = py::cast<int>(op_run_info->inputs_mask[index]);
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // mark tensors, data : 0, weight : 1, valuenode: 2
    // One python input may expand to several tensors (tuples); replicate its mask.
    std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
}
  392. void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
  393. MS_EXCEPTION_IF_NULL(input_tensors);
  394. if (input_tensors->size() != tensors_mask.size()) {
  395. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
  396. << tensors_mask.size();
  397. }
  398. std::vector<tensor::TensorPtr> new_input_tensors;
  399. for (size_t index = 0; index < tensors_mask.size(); ++index) {
  400. if (tensors_mask[index] != kValueNodeTensorMask) {
  401. new_input_tensors.push_back(input_tensors->at(index));
  402. }
  403. }
  404. *input_tensors = new_input_tensors;
  405. }
// Run a single op through the MindSpore backend session (Ascend/GPU only),
// building (or reusing via the graph_info cache key) a single-op graph.
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_enable_pynative_infer(true);
  std::string device_target = ms_context->device_target();
  if (device_target != kAscendDevice && device_target != kGPUDevice) {
    MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
  }
  if (session == nullptr) {
    session = session::SessionFactory::Get().Create(device_target);
  }
  MS_EXCEPTION_IF_NULL(session);
  // NOTE(review): Init runs on every op call, not only when the session is first
  // created — confirm it is idempotent or hoist it into the creation branch.
  session->Init(ms_context->device_id());
  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int> tensors_mask;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  // get graph info for checking it whether existing in the cache
  std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
  session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask);
  // Value-node tensors were baked into the graph; drop them before execution.
  EraseValueNodeTensor(tensors_mask, &input_tensors);
  py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
  ms_context->set_enable_pynative_infer(false);
  *status = PYNATIVE_SUCCESS;
  return result;
}
// Dispatch op execution to the backend selected by `backend_policy`, falling
// back to the VM where the policy allows it. `status` reports the outcome;
// an unconfigured policy logs an error and returns a default py::object.
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info,
                                  PynativeStatusCode *const status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // use vm only
      MS_LOG(INFO) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // use GE first, use vm when GE fails
      MS_LOG(INFO) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // use Ms first, use others when ms failed
      MS_LOG(INFO) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!!!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}
// Record the just-executed op as a CNode in the current graph for autodiff.
// Returns nullptr when gradient recording is off or no graph has been opened.
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
  if (!grad_flag_ || graph_info_map_.size() == 0) {
    return nullptr;
  }
  std::vector<AnfNodePtr> inputs;
  auto prim = op_exec_info->py_primitive;
  inputs.push_back(NewValueNode(prim));
  py::tuple op_masks = op_exec_info->inputs_mask;
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < args.size(); i++) {
    // Map each python argument to its graph node (parameter, prior output, or constant).
    auto node = GetInput(args[i], op_masks[i]);
    args_spec_list.push_back(node->abstract());
    inputs.push_back(node);
  }
  auto cnode = curr_g_->NewCNode(inputs);
  MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4);
  py::object out_real = out;
  if (out.size() == 1) {
    // Single-output ops are unwrapped from their 1-tuple.
    MS_LOG(DEBUG) << "MakeCnode out size is one.";
    out_real = out[0];
  }
  std::string obj_id = GetId(out_real);
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    if (value.size() > 1) {
      // Map each tuple element id to (cnode, index) so later getitem lookups resolve.
      for (int i = 0; i < static_cast<int>(value.size()); i++) {
        auto value_id = GetId(value[i]);
        MS_LOG(DEBUG) << "MakeCnode set node id " << value_id;
        set_obj_node_map(curr_g_, value_id, cnode, i);
      }
    }
  }
  MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id;
  set_obj_node_map(curr_g_, obj_id, cnode);
  set_pyobj(curr_g_, obj_id);
  return cnode;
}
// Look up the graph node recorded for `obj`. When the object was an element of
// a tuple output, emit TupleGetItem nodes along the stored index chain.
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
  auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
  if (out.second.size() == 1 && out.second[0] == -1) {
    // Index chain {-1} marks a whole (non-tuple-element) output.
    return out.first;
  }
  auto node = out.first;
  MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString();
  for (auto &idx : out.second) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
    node = curr_g_->NewCNode(tuple_get_item_inputs);
  }
  MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
  return node;
}
// Execute one op on the policy-selected backend and, if gradient recording is
// active, record it into the current graph. Returns an empty tuple on failure.
py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  mindspore::parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->PynativeInitGe();
  backend_policy = kMsBackendGeOnly;
#endif
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    // Ops the device backends cannot handle are forced onto the VM.
    backend_policy = kMsBackendVmOnly;
  }
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // returns a null py::tuple on error
  py::tuple err_ret(0);
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
    return err_ret;
  }
  auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
  if (node != nullptr) {
    node->set_abstract(op_exec_info->abstract);
    MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}
// Python-facing RunOp entry taking (op_name, primitive, inputs). When abstract
// inference already produced a concrete value (constant folding, or a primitive
// with const_value), return it directly without executing on any backend.
py::tuple RunOp(const py::args &args) {
  MS_LOG(DEBUG) << "RunOp start" << args.size();
  py::list args_input = args[PY_INPUTS];
  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args, &args_input);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->abstract != nullptr) {
    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
    if (!output["value"].is_none()) {
      // Infer produced a constant; no need to run the op.
      py::tuple value_ret(1);
      value_ret[0] = output["value"];
      return value_ret;
    }
    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
      // const_value primitives yield a placeholder empty-string result.
      py::tuple value_ret(1);
      value_ret[0] = "";
      return value_ret;
    }
  }
  return RunOp(op_exec_info, args_input);
}
// Drop the cached backend session (used at shutdown/reset).
void ClearPyNativeSession() { session = nullptr; }

PynativeExecutor::~PynativeExecutor() { ClearRes(); }

// Gradient recording stays off until a grad graph is requested.
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
// Begin recording a FuncGraph for `cell`. The first NewGraph also creates the
// top graph and the derivative-builder graph; nested cells push the current
// graph and switch to a fresh one. Cells already compiled are skipped.
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  auto cell_id = GetId(cell);
  if (cell_graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "Newgraph already compiled";
    return;
  }
  auto g = std::make_shared<FuncGraph>();
  if (top_g_ == nullptr) {
    // First cell: this graph is both top and current; df_builder_ holds grads.
    top_g_ = curr_g_ = g;
    df_builder_ = std::make_shared<FuncGraph>();
    MS_LOG(DEBUG) << "First new graph" << top_g_.get();
    Pushp();
  } else {
    Pushp();
    curr_g_ = g;
  }
  if (graph_info_map_.count(g) == 0) {
    graph_info_map_[g] = GraphInfo();
  }
  // Register one graph parameter per python argument, keyed by the arg's object id.
  for (size_t i = 0; i < args.size(); i++) {
    auto new_param = g->add_parameter();
    std::string param_obj = GetId(args[i]);
    graph_info_map_[g].param_map[param_obj] = new_param;
  }
}
// Convert a python object into a constant ValueNode and register it under `obj_id`.
// NOTE(review): the ConvertData return status is ignored — on conversion failure
// this produces a ValueNode holding nullptr; confirm callers cannot hit that path.
AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
  ValuePtr converted_ret = nullptr;
  parse::ConvertData(obj, &converted_ret);
  auto node = NewValueNode(converted_ret);
  set_obj_node_map(curr_g_, obj_id, node);
  return node;
}
// Resolve a python input object to a graph node:
//  - Parameter objects (op_mask truthy) become named free parameters on the
//    derivative-builder graph, created once and reused by object id.
//  - Objects already recorded map to their graph parameter or prior op output.
//  - Tuples become MakeTuple CNodes (or a single constant when not tensors).
//  - Anything else becomes a constant ValueNode.
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) {
  AnfNodePtr node = nullptr;
  std::string obj_id = GetId(obj);
  if (op_mask != nullptr && py::cast<bool>(op_mask)) {
    MS_LOG(DEBUG) << "Topgraph free parameter";
    // get the parameter name from parameter object
    auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    std::string param_name = py::cast<std::string>(name_attr);
    if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
      // First use of this weight: add it as a named free parameter with a default value.
      auto free_param = df_builder_->add_parameter();
      free_param->set_name(param_name);
      auto free_param_new = std::make_shared<ParamValuePy>(obj);
      free_param->set_default_param(free_param_new);
      free_param->debug_info()->set_name(param_name);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      graph_info_map_[df_builder_].param_map[obj_id] = free_param;
      return free_param;
    }
    return graph_info_map_[df_builder_].param_map[obj_id];
  }
  // if input is graph output
  if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
    // op(x, y)
    node = graph_info_map_[curr_g_].param_map[obj_id];
  } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj);
  } else if (py::isinstance<py::tuple>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    auto tuple = obj.cast<py::tuple>();
    // cell((1,2)): support not mix (scalar, tensor)
    if (tuple.size() > 0 && !py::isinstance<tensor::Tensor>(tuple[0])) {
      // Non-tensor tuple: treat as a single constant value.
      return MakeValueNode(obj, obj_id);
    }
    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple_size = static_cast<int>(tuple.size());
    for (int i = 0; i < tuple_size; i++) {
      args.push_back(GetInput(tuple[i], py::object()));
    }
    auto cnode = curr_g_->NewCNode(args);
    set_obj_node_map(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    node = MakeValueNode(obj, obj_id);
  }
  MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id;
  return node;
}
  667. // for output[0][1] need getitem multi
  668. void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx) {
  669. if (py::isinstance<py::tuple>(obj)) {
  670. auto tuple = obj.cast<py::tuple>();
  671. for (int i = 0; i < static_cast<int>(tuple.size()); i++) {
  672. std::vector<int> tmp = idx;
  673. tmp.push_back(i);
  674. set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, tmp);
  675. SetTupleOutput(tuple[i], cnode, tmp);
  676. }
  677. }
  678. }
// Save the current graph so nested cell construction can restore it via Popp().
void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
  680. void PynativeExecutor::Popp() {
  681. if (graph_p_.empty()) {
  682. MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
  683. }
  684. curr_g_ = graph_p_.top();
  685. graph_p_.pop();
  686. }
// Finish recording the forward graph of `cell`: make sure the cell's output
// maps to a node in curr_g_ (building a MakeTuple for multi-value returns),
// then delegate to EndGraphByOutId to set the output and build the gradient.
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  auto cell_id = GetId(cell);
  if (cell_graph_map_.count(cell_id) != 0) {
    // This cell was already recorded; nothing to rebuild.
    MS_LOG(DEBUG) << "Endgraph already compiled";
    return;
  }
  cell_graph_map_[cell_id] = curr_g_;
  auto out_id = GetId(out);
  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
    // cell construct return x, y
    if (py::isinstance<py::tuple>(out)) {
      std::vector<AnfNodePtr> args;
      args.push_back(NewValueNode(prim::kPrimMakeTuple));
      auto tuple = out.cast<py::tuple>();
      MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size();
      auto tuple_size = static_cast<int>(tuple.size());
      // The CNode is created before its input list is complete so that each
      // tuple element can be mapped to (cnode, index) while inputs are being
      // collected; set_inputs below installs the finished input list.
      auto cnode = curr_g_->NewCNode(args);
      for (int i = 0; i < tuple_size; i++) {
        args.push_back(GetInput(tuple[i], py::object()));
        set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
        SetTupleOutput(tuple[i], cnode, std::vector<int>{i});
      }
      cnode->set_inputs(args);
      set_obj_node_map(curr_g_, out_id, cnode);
    } else {
      MS_LOG(ERROR) << "Graph has no this out: " << out_id;
      return;
    }
  }
  EndGraphByOutId(out_id, cell, out, args);
}
// Set the output of curr_g_, build its gradient graph, and either splice the
// finished sub-graph into its parent as a call node (nested cell) or, for the
// top graph, resolve the grad graph and publish it in the pipeline resource.
void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out,
                                       const py::args &args) {
  AnfNodePtr output_node;
  if (graph_info_map_[curr_g_].param_map.count(out_id)) {
    output_node = graph_info_map_[curr_g_].param_map[out_id];
  } else {
    output_node = GetObjNode(out);
  }
  curr_g_->set_output(output_node);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(curr_g_));
  MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
  resource_->manager()->AddFuncGraph(curr_g_);
  // custom bprop debug
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    // Cross-link the user-defined bprop graph and its primal graph through
    // func-graph transforms so the ad pass can locate both.
    MS_LOG(DEBUG) << "Use cell custom bprop function.";
    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
    if (bprop_graph != nullptr) {
      (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
    }
  }
  // Build the gradient graph; the last flag marks whether this is the top graph.
  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
  if (curr_g_ != top_g_) {
    // Nested cell: pop back to the parent graph and insert a call to this
    // sub-graph with the original python args as call inputs.
    Popp();
    for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], py::object());
      inputs.push_back(input);
    }
    auto out_cnode = curr_g_->NewCNode(inputs);
    set_pyobj(curr_g_, GetId(cell));
    if (py::isinstance<py::tuple>(out)) {
      // Map each tuple element (recursively) to an index into the call node.
      auto out_list = py::cast<py::tuple>(out);
      auto out_size = static_cast<int>(out_list.size());
      for (int i = 0; i < out_size; i++) {
        set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
        SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i});
      }
    }
    set_obj_node_map(curr_g_, GetId(out), out_cnode);
  } else {
    // Top graph: resolve symbols in the grad graph and store it as the
    // resource's func graph for GradNet to consume.
    parse::ResolveFuncGraph(newfg, resource_);
    resource_->set_func_graph(newfg);
  }
}
  763. std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) {
  764. std::vector<AnfNodePtr> w_args;
  765. if (py::hasattr(weights, "__parameter_tuple__")) {
  766. auto tuple = weights.cast<py::tuple>();
  767. MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size();
  768. w_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  769. for (size_t it = 0; it < tuple.size(); ++it) {
  770. auto param = tuple[it];
  771. auto param_id = GetId(param);
  772. AnfNodePtr para_node = nullptr;
  773. if (graph_info_map_[df_builder_].param_map.count(param_id)) {
  774. para_node = graph_info_map_[df_builder_].param_map[param_id];
  775. AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
  776. AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
  777. auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
  778. AnfNodePtr ref_key_node = NewValueNode(refkey);
  779. AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
  780. w_args.push_back(ref_node);
  781. }
  782. }
  783. } else {
  784. MS_LOG(EXCEPTION) << "training not paramter_tuple";
  785. }
  786. return w_args;
  787. }
  788. abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) {
  789. abstract::AbstractBasePtrList args_spec;
  790. std::size_t size = args.size();
  791. for (std::size_t i = 0; i < size; i++) {
  792. ValuePtr converted = nullptr;
  793. bool succ = parse::ConvertData(args[i], &converted);
  794. if (!succ) {
  795. MS_LOG(EXCEPTION) << "Args convert error";
  796. }
  797. bool broaden = true;
  798. auto abs = abstract::FromValue(converted, broaden);
  799. args_spec.push_back(abs);
  800. auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]);
  801. param_node->set_abstract(abs);
  802. }
  803. for (const auto &param : df_builder_->parameters()) {
  804. auto param_node = std::static_pointer_cast<Parameter>(param);
  805. if (param_node->has_default()) {
  806. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
  807. AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
  808. if (ptr == nullptr) {
  809. MS_LOG(EXCEPTION) << "Args convert error";
  810. }
  811. args_spec.push_back(ptr);
  812. param_node->set_abstract(ptr);
  813. }
  814. }
  815. return args_spec;
  816. }
// Build and compile the gradient graph for `cell` on first call: create one
// placeholder parameter per forward input (prepended to df_builder_'s weight
// parameters), wrap the forward graph with GradOperation, infer the argument
// specs, then run the pynative optimize/task-emit/execute pipeline actions.
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  MS_LOG(INFO) << "GradNet start" << args.size();
  std::size_t size = args.size();
  auto cell_id = GetId(cell);
  if (graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "GradNet already compiled";
    return;
  }
  MS_LOG(DEBUG) << "GradNet first compiled";
  // One fresh parameter per forward input; existing df_builder_ parameters
  // (the captured weight parameters) are appended after them.
  std::vector<AnfNodePtr> new_params;
  for (size_t i = 0; i < size; i++) {
    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
    new_params.push_back(p);
  }
  MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
  df_builder_->set_parameters(new_params);
  resource_->manager()->SetParameters(df_builder_, new_params);
  std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
  MS_EXCEPTION_IF_NULL(resource_->func_graph());
  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
  resource_->set_func_graph(g);
  resource_->manager()->KeepRoots({g});
  // get the parameters items and add the value to args_spec
  abstract::AbstractBasePtrList args_spec = GetArgsSpec(args);
  MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
  resource_->set_args_spec(args_spec);
  MS_LOG(DEBUG) << "Start opt";
  // Create backend and session
  resource_->results()[pipeline::kBackend] = compile::CreateBackend();
  graph_map_[cell_id] = g;
  PynativeOptimizeAction(resource_);
  TaskEmitAction(resource_);
  ExecuteAction(resource_);
  // Release compile-time state; the executable stays in resource_->results().
  resource_->Clean();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
  856. void PynativeExecutor::Clear(const std::string &flag) {
  857. if (flag == "resource") {
  858. MS_LOG(INFO) << "Clear res";
  859. Clean();
  860. return;
  861. }
  862. MS_LOG(INFO) << "Clear";
  863. top_g_ = nullptr;
  864. curr_g_ = nullptr;
  865. graph_info_map_.clear();
  866. std::stack<FuncGraphPtr>().swap(graph_p_);
  867. }
// Release all graph-building state: clears the graph maps and stack via
// Clear(), resets the grad flag and df_builder_, then reclaims autodiff and
// optimizer resources.
void PynativeExecutor::Clean() {
  MS_LOG(INFO) << "Clean all res";
  Clear();
  grad_flag_ = false;
  df_builder_ = nullptr;
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
// Full teardown: Clean() plus dropping the pipeline resource itself.
void PynativeExecutor::ClearRes() {
  Clean();
  resource_.reset();
}
  880. py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
  881. VectorRef arg_list;
  882. pipeline::ProcessVmArgInner(args, resource_, &arg_list);
  883. if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
  884. !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
  885. MS_LOG(EXCEPTION) << "Can't find run graph func for ";
  886. }
  887. compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  888. if (run == nullptr) {
  889. MS_LOG(EXCEPTION) << "Can't find run graph func for ";
  890. }
  891. std::string backend = MsContext::GetInstance()->backend_policy();
  892. MS_LOG(DEBUG) << "Eval run" << backend;
  893. BaseRef value = (*run)(arg_list);
  894. MS_LOG(DEBUG) << "Run end" << value.ToString();
  895. return BaseRefToPyData(value);
  896. }
// Wrap graph `g` with GradOperation inside df_builder_: build the grad func
// via grad_op->GetGrad over the top graph's parameters and the weight nodes,
// call it with df_builder_'s first `arg_size` parameters, and return
// df_builder_ as the new root graph.
FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op,
                                         const std::vector<AnfNodePtr> &weights, size_t arg_size) {
  auto nparam = top_g_->parameters().size();
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  df_builder_->debug_info()->set_name(ss.str());
  auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights);
  std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
  // The first arg_size parameters of df_builder_ are the forward-input
  // placeholders created in GradNet; they become the grad func's call args.
  for (size_t i = 0; i < arg_size; ++i) {
    inputs.push_back(df_builder_->parameters()[i]);
  }
  auto out = df_builder_->NewCNode(inputs);
  df_builder_->set_output(out);
  resource_->manager()->AddFuncGraph(df);
  resource_->manager()->AddFuncGraph(df_builder_);
  return df_builder_;
}
// Expose PynativeExecutor to python as the singleton class `PynativeExecutor_`
// used by the pynative frontend (graph recording, grad building, execution).
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
                           .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""),
                                "Executor run function.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.");
                       }));
  927. } // namespace pynative
  928. } // namespace mindspore