
pynative_execute.cc 35 kB

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pynative/pynative_execute.h"

#include <typeinfo>
#include <map>
#include <set>
#include <unordered_set>
#include <algorithm>

#include "ir/param_value_py.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "operator/composite/do_signature.h"
#include "pipeline/parse/data_converter.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/parse/resolve.h"
#include "pipeline/static_analysis/prim.h"
#include "session/session_factory.h"
#include "pre_activate/pass/const_input_to_attr_registry.h"
#include "pre_activate/common/helper.h"
#include "pipeline/action.h"
#include "pynative/base.h"
#include "pybind_api/api_register.h"
#include "vm/transform.h"
#include "optimizer/ad/grad.h"
#include "pipeline/resource.h"
#include "pipeline/pipeline.h"
#include "pipeline/pass.h"

#ifdef ENABLE_GE
#include "pynative/pynative_execute_ge.h"
#endif
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward"};

namespace mindspore {
namespace pynative {
static std::shared_ptr<session::SessionBasic> session = nullptr;
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
ResourcePtr PynativeExecutor::resource_;
inline ValuePtr PyAttrValue(const py::object &obj) {
  ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  if (!converted_ret) {
    MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
  }
  return converted_ret;
}
std::string GetId(const py::object &obj) {
  py::object to_process = obj;
  std::string prefix = "";
  if (py::isinstance<py::tuple>(to_process)) {
    auto p_list = py::cast<py::tuple>(to_process);
    if (p_list.size() == 0) {
      return "empty";
    }
    to_process = p_list[0];
    prefix = "tuple:";
    if (!py::isinstance<tensor::Tensor>(to_process)) {
      std::string key = "";
      for (size_t i = 0; i < p_list.size(); ++i) {
        key += std::string(py::str(p_list[i])) + ":";
      }
      return prefix + key;
    }
  }
  if (py::isinstance<py::int_>(to_process)) {
    return prefix + std::string(py::str(to_process));
  }
  if (py::isinstance<py::float_>(to_process)) {
    return prefix + std::string(py::str(to_process));
  }
  if (py::isinstance<tensor::Tensor>(to_process)) {
    auto tensor_ptr = py::cast<tensor::TensorPtr>(to_process);
    return prefix + tensor_ptr->id();
  }

  py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  return py::cast<std::string>(ret);
}
py::object GetTupleObj(const py::object &obj) {
  py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
  return obj_tuple;
}
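// Apply implicit type promotion to the op inputs: when the primitive's signature groups
// arguments by dtype, scalar int/float arguments in a group that also contains a tensor
// are converted to tensors with that tensor's dtype.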
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
  auto &py_args = *out_args;
  for (size_t i = 0; i < args.size(); ++i) {
    py_args[i] = GetTupleObj(args[i]);
  }
  auto signature = prim->signatures();
  std::vector<SignatureEnumDType> dtypes;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                       [](const Signature &sig) { return sig.dtype; });
  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
    return;
  }
  std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
  for (size_t i = 0; i < dtypes.size(); ++i) {
    auto it = type_indexs.find(dtypes[i]);
    if (it == type_indexs.end()) {
      (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
    } else {
      it->second.push_back(i);
    }
  }
  std::map<SignatureEnumDType, size_t> dst_type;
  for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
    auto type = it->first;
    auto indexs = it->second;
    if (indexs.size() < 2) {
      continue;
    }
    size_t m_index = indexs[0];
    for (size_t i = 1; i < indexs.size(); ++i) {
      if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
        m_index = indexs[i];
      }
    }
    (void)dst_type.insert(std::make_pair(type, m_index));
  }
  for (size_t i = 0; i < py_args.size(); ++i) {
    auto it = dst_type.find(dtypes[i]);
    if (it != dst_type.end() && it->second != i &&
        (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
      auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
      if (py::isinstance<py::int_>(py_args[i])) {
        py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
      } else {
        py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
      }
      continue;
    }
  }
}
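// Convert each Python argument to an abstract value and run single-primitive inference
// to obtain the abstract (type/shape) of the op's output.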
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
  size_t size = py_args.size();
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < size; i++) {
    ValuePtr input_value = PyAttrValue(py_args[i]);
    if (input_value->isa<tensor::Tensor>()) {
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
    } else {
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
    }
  }
  AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  op_exec_info->abstract = infer_res;
}
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(ERROR) << "Four args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
  MS_EXCEPTION_IF_NULL(op_exec_info);
  op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  auto pyobj = prim->GetPyObj();
  if (pyobj == nullptr) {
    MS_LOG(EXCEPTION) << "pyobj is empty";
  }
  py::list a = args[PY_INPUTS];
  size_t input_num = a.size();
  op_exec_info->op_inputs = py::tuple(input_num);
  ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
  // use python infer method
  if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
    PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  op_exec_info->inputs_mask = args[PY_INPUT_MASK];
  if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
    MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size is not equal to the inputs mask size";
    return nullptr;
  }
  return op_exec_info;
}
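// Build the cache key of a single-op graph from the shape/dtype of every tensor input,
// the primitive's address and the inferred output abstract.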
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
                                 const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  std::string graph_info;
  // get input tensor info
  size_t input_num = op_exec_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    auto input = op_exec_info->op_inputs[index];
    if (py::isinstance<tensor::Tensor>(input)) {
      auto tensor_ptr = py::cast<tensor::TensorPtr>(input);
      (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
    }
  }
  // get prim and abstract info
  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
                          op_exec_info->abstract->ToString());
  return graph_info;
}
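// Fallback execution path: run the op through the primitive's Python compute function.
// HookBackward is special-cased and simply forwards its inputs.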
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  if (op_exec_info->op_name == "HookBackward") {
    auto op_inputs = op_exec_info->op_inputs;
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      if (py::hasattr(input, "__parameter__")) {
        result[i] = py::getattr(input, "data");
      } else {
        auto tensor = py::cast<tensor::TensorPtr>(op_inputs[i]);
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto func = op_exec_info->py_primitive->GetComputeFunction();
  if (py::isinstance<py::none>(func)) {
    MS_LOG(ERROR) << "VM failed to get func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }

  // execute op
  py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs));
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "RunOpInVM end";
  return std::move(result);
}
bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                  const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(op_prim);
  auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  if (input_names_value == nullptr) {
    return false;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  if (input_index >= input_names_vec.size()) {
    MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
  }
  if (input_attrs.find(input_index) != input_attrs.end()) {
    ValuePtr value = parse::data_converter::PyDataToValue(input_object);
    MS_EXCEPTION_IF_NULL(value);
    auto input_name = input_names_vec[input_index];
    op_prim->set_attr(input_name, value);
    return true;
  }
  return false;
}
void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
                              std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    input_tensors->push_back(tensor);
  }
  op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
}

void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }
  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->push_back(tensor_ptr);
}
void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                                  std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }
  auto tuple_inputs = py::cast<py::tuple>(input_object);
  if (tuple_inputs.size() == 0) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }
  if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
    PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
  } else {
    ConvertValueTupleToTensor(input_object, input_tensors);
    *tensor_mask = kValueNodeTensorMask;
  }
}

void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                             std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (py::isinstance<tensor::Tensor>(input_object)) {
    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  } else if (py::isinstance<py::float_>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::int_>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::array>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
  } else if (py::isinstance<py::list>(input_object)) {
    auto list_inputs = py::cast<py::list>(input_object);
    py::tuple tuple_inputs(list_inputs.size());
    for (size_t i = 0; i < tuple_inputs.size(); ++i) {
      tuple_inputs[i] = list_inputs[i];
    }
    ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    return;
  } else {
    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->push_back(tensor_ptr);
}
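// Convert every op input into input tensors for the session: registered const inputs
// become primitive attributes, the rest are converted to tensors and tagged with a mask
// (data: 0, weight: 1, value node: 2).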
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size "
                      << op_run_info->inputs_mask.size();
  }
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  size_t input_num = op_run_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    // convert const input to attr
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // convert const and tuple input to tensor
    int tensor_mask = py::cast<int>(op_run_info->inputs_mask[index]);
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // mark tensors, data : 0, weight : 1, valuenode: 2
    std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
}
void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  if (input_tensors->size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  std::vector<tensor::TensorPtr> new_input_tensors;
  for (size_t index = 0; index < tensors_mask.size(); ++index) {
    if (tensors_mask[index] != kValueNodeTensorMask) {
      new_input_tensors.push_back(input_tensors->at(index));
    }
  }
  *input_tensors = new_input_tensors;
}
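// Run a single op on the MindSpore backend: build (or reuse, keyed by graph_info) the
// single-op graph in the session, drop value-node tensors and execute it.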
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_enable_pynative_infer(true);
  std::string device_target = ms_context->device_target();
  if (device_target != kAscendDevice && device_target != kGPUDevice) {
    MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
  }
  if (session == nullptr) {
    session = session::SessionFactory::Get().Create(device_target);
  }
  MS_EXCEPTION_IF_NULL(session);
  session->Init(ms_context->device_id());

  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int> tensors_mask;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  // get graph info for checking whether the graph already exists in the cache
  std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
  session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask);
  EraseValueNodeTensor(tensors_mask, &input_tensors);
  py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
  ms_context->set_enable_pynative_infer(false);
  *status = PYNATIVE_SUCCESS;
  return result;
}
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info,
                                  PynativeStatusCode *const status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // use vm only
      MS_LOG(INFO) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // use GE first, use vm when GE fails
      MS_LOG(INFO) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // use MS first, use other backends when MS fails
      MS_LOG(INFO) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}
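// Record an executed op as a CNode of the current graph when grad is enabled, and map the
// Python output object(s) to that node so later ops can reference them as inputs.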
AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
  if (!grad_flag_ || graph_info_map_.size() == 0) {
    return nullptr;
  }
  std::vector<AnfNodePtr> inputs;
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  inputs.push_back(NewValueNode(prim));
  py::tuple op_masks = args[PY_INPUT_MASK];
  py::list op_args = args[PY_INPUTS];
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < op_args.size(); i++) {
    auto node = GetInput(op_args[i], op_masks[i]);
    args_spec_list.push_back(node->abstract());
    inputs.push_back(node);
  }
  auto cnode = curr_g_->NewCNode(inputs);
  MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString();
  py::object out_real = out;
  if (out.size() == 1) {
    MS_LOG(DEBUG) << "MakeCnode out size is one.";
    out_real = out[0];
  }
  std::string obj_id = GetId(out_real);
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    if (value.size() > 1) {
      for (int i = 0; i < static_cast<int>(value.size()); i++) {
        auto value_id = GetId(value[i]);
        set_obj_node_map(curr_g_, value_id, cnode, i);
      }
    }
  }
  set_obj_node_map(curr_g_, obj_id, cnode);
  set_pyobj(curr_g_, obj_id);
  return cnode;
}
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
  auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
  if (out.second == -1) {
    return out.first;
  }
  std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), out.first,
                                                NewValueNode(out.second)};
  return curr_g_->NewCNode(tuple_get_item_inputs);
}
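// Entry point called from Python for every op executed in PyNative mode: build the op info,
// infer its output, select a backend policy, run the op and record it for autodiff.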
py::tuple RunOp(const py::args &args) {
  MS_LOG(DEBUG) << "RunOp start " << args.size();
  py::object result;
  // returns a null py::tuple on error
  py::tuple err_ret(0);
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;

  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->abstract != nullptr) {
    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
    if (!output["value"].is_none()) {
      py::tuple value_ret(1);
      value_ret[0] = output["value"];
      return value_ret;
    }
    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
      py::tuple value_ret(1);
      value_ret[0] = "";
      return value_ret;
    }
  }

  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  mindspore::parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->PynativeInitGe();
  backend_policy = kMsBackendGeOnly;
#endif
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
    return err_ret;
  }

  auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
  if (node != nullptr) {
    node->set_abstract(op_exec_info->abstract);
    MS_LOG(DEBUG) << "RunOp MakeCnode, new node is: " << node->DebugString();
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}
void ClearPyNativeSession() { session = nullptr; }

PynativeExecutor::~PynativeExecutor() { Clean(); }

PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  auto cell_id = GetId(cell);
  if (cell_graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "Newgraph already compiled";
    return;
  }
  auto g = std::make_shared<FuncGraph>();
  if (top_g_ == nullptr) {
    top_g_ = curr_g_ = g;
    df_builder_ = std::make_shared<FuncGraph>();
    MS_LOG(DEBUG) << "First new graph " << top_g_.get();
    Pushp();
  } else {
    Pushp();
    curr_g_ = g;
  }
  if (graph_info_map_.count(g) == 0) {
    graph_info_map_[g] = GraphInfo();
  }
  for (size_t i = 0; i < args.size(); i++) {
    auto new_param = g->add_parameter();
    std::string param_obj = GetId(args[i]);
    graph_info_map_[g].param_map[param_obj] = new_param;
  }
}
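// Map a Python input object to an AnfNode of the current graph: a weight becomes a free
// parameter of df_builder_; otherwise look it up as a graph parameter or a recorded op
// output, expand tuples into make_tuple, and fall back to a constant value node.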
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) {
  AnfNodePtr node = nullptr;
  std::string obj_id = GetId(obj);

  if (op_mask != nullptr && py::cast<bool>(op_mask)) {
    MS_LOG(DEBUG) << "Topgraph free parameter";
    // get the parameter name from parameter object
    auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    std::string param_name = py::cast<std::string>(name_attr);
    if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
      auto free_param = df_builder_->add_parameter();
      free_param->set_name(param_name);
      auto free_param_new = std::make_shared<ParamValuePy>(obj);
      free_param->set_default_param(free_param_new);
      free_param->debug_info()->set_name(param_name);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      graph_info_map_[df_builder_].param_map[obj_id] = free_param;
      return free_param;
    }
    return graph_info_map_[df_builder_].param_map[obj_id];
  }

  // if input is graph output
  if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
    // op(x, y)
    node = graph_info_map_[curr_g_].param_map[obj_id];
  } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj);
  } else if (py::isinstance<py::tuple>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple = obj.cast<py::tuple>();
    auto tuple_size = static_cast<int>(tuple.size());
    for (int i = 0; i < tuple_size; i++) {
      args.push_back(GetInput(tuple[i], py::object()));
    }
    auto cnode = curr_g_->NewCNode(args);
    set_obj_node_map(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    // out = op(x, 1)
    ValuePtr converted_ret = nullptr;
    parse::ConvertData(obj, &converted_ret);
    node = NewValueNode(converted_ret);
    set_obj_node_map(curr_g_, obj_id, node);
  }

  MS_LOG(DEBUG) << "Now getinput " << py::str(obj) << " node " << node->ToString();
  return node;
}
void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }

void PynativeExecutor::Popp() {
  if (graph_p_.empty()) {
    MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
  }
  curr_g_ = graph_p_.top();
  graph_p_.pop();
}
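// Finish the graph of a cell: set its output node, differentiate it with ad::Grad, then
// either inline the cell call into the caller's graph or install the grad graph on the resource.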
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  auto cell_id = GetId(cell);
  if (cell_graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "Endgraph already compiled";
    return;
  }
  cell_graph_map_[cell_id] = curr_g_;
  auto out_id = GetId(out);
  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
    // cell construct return x, y
    if (py::isinstance<py::tuple>(out)) {
      std::vector<AnfNodePtr> args;
      args.push_back(NewValueNode(prim::kPrimMakeTuple));
      auto tuple = out.cast<py::tuple>();
      MS_LOG(DEBUG) << "End graph start tuple size " << tuple.size();
      auto tuple_size = static_cast<int>(tuple.size());
      auto cnode = curr_g_->NewCNode(args);
      for (int i = 0; i < tuple_size; i++) {
        args.push_back(GetInput(tuple[i], py::object()));
        set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
      }
      cnode->set_inputs(args);
      set_obj_node_map(curr_g_, out_id, cnode);
    } else {
      MS_LOG(ERROR) << "Graph has no node for this output: " << out_id;
      return;
    }
  }

  auto output_node = GetObjNode(out);
  curr_g_->set_output(output_node);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(curr_g_));
  MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
  resource_->manager()->AddFuncGraph(curr_g_);
  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
  if (curr_g_ != top_g_) {
    Popp();
    for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], py::object());
      inputs.push_back(input);
    }
    auto out_cnode = curr_g_->NewCNode(inputs);
    set_pyobj(curr_g_, GetId(cell));
    if (py::isinstance<py::tuple>(out)) {
      auto out_list = py::cast<py::tuple>(out);
      auto out_size = static_cast<int>(out_list.size());
      for (int i = 0; i < out_size; i++) {
        set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
      }
    }
    set_obj_node_map(curr_g_, GetId(out), out_cnode);
  } else {
    parse::ResolveFuncGraph(newfg, resource_);
    resource_->set_func_graph(newfg);
  }
}
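// Build and compile the gradient graph of a cell: prepend the call arguments as parameters of
// df_builder_, wrap the requested weights in make_ref nodes, wrap the forward graph with
// GradGraph, and run the optimize/task-emit/execute pipeline actions.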
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  MS_LOG(INFO) << "GradNet start " << args.size();

  std::size_t size = args.size();
  auto cell_id = GetId(cell);
  if (graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "GradNet already compiled";
    return;
  }
  MS_LOG(DEBUG) << "GradNet first compiled";
  std::vector<AnfNodePtr> new_params;
  for (size_t i = 0; i < size; i++) {
    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
    new_params.push_back(p);
  }
  MS_LOG(DEBUG) << "GradNet start weight size " << df_builder_->parameters().size();
  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
  df_builder_->set_parameters(new_params);
  resource_->manager()->SetParameters(df_builder_, new_params);

  std::vector<AnfNodePtr> w_args;
  if (py::hasattr(weights, "__parameter_tuple__")) {
    auto tuple = weights.cast<py::tuple>();
    MS_LOG(DEBUG) << "GradNet start weights tuple size " << tuple.size();
    w_args.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t it = 0; it < tuple.size(); ++it) {
      auto param = tuple[it];
      auto param_id = GetId(param);
      AnfNodePtr para_node = nullptr;
      if (graph_info_map_[df_builder_].param_map.count(param_id)) {
        para_node = graph_info_map_[df_builder_].param_map[param_id];
        AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
        AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
        auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
        AnfNodePtr ref_key_node = NewValueNode(refkey);
        AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
        w_args.push_back(ref_node);
      }
    }
  } else {
    MS_LOG(EXCEPTION) << "The weights for training should be a ParameterTuple";
  }
  MS_EXCEPTION_IF_NULL(resource_->func_graph());
  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
  resource_->set_func_graph(g);

  // get the parameters items and add the value to args_spec
  abstract::AbstractBasePtrList args_spec;
  for (std::size_t i = 0; i < size; i++) {
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(args[i], &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "Args convert error";
    }
    bool broaden = true;
    auto abs = abstract::FromValue(converted, broaden);
    args_spec.push_back(abs);
    auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]);
    param_node->set_abstract(abs);
  }
  for (const auto &param : df_builder_->parameters()) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    if (param_node->has_default()) {
      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
      AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
      if (ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Args convert error";
      }
      args_spec.push_back(ptr);
      param_node->set_abstract(ptr);
    }
  }
  MS_LOG(DEBUG) << "Args_spec size " << args_spec.size();
  resource_->set_args_spec(args_spec);

  MS_LOG(DEBUG) << "Start opt";
  // Create backend and session
  resource_->results()[pipeline::kBackend] = compile::CreateBackend();
  graph_map_[cell_id] = g;
  PynativeOptimizeAction(resource_);
  TaskEmitAction(resource_);
  ExecuteAction(resource_);
  resource_->Clean();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
void PynativeExecutor::Clear() {
  MS_LOG(INFO) << "Clear all res";
  top_g_ = curr_g_ = nullptr;
  std::stack<FuncGraphPtr>().swap(graph_p_);
  graph_info_map_.clear();
}

void PynativeExecutor::Clean() {
  graph_map_.clear();
  cell_graph_map_.clear();
  Clear();
  resource_.reset();
}
py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
  VectorRef arg_list;
  pipeline::ProcessVmArgInner(args, resource_, &arg_list);
  if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
      !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Can't find run graph func";
  }
  compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  if (run == nullptr) {
    MS_LOG(EXCEPTION) << "Can't find run graph func";
  }
  std::string backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  BaseRef value = (*run)(arg_list);
  MS_LOG(DEBUG) << "Run end " << value.ToString();
  return BaseRefToPyData(value);
}
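// Wrap the forward graph with GradOperation::GetGrad inside df_builder_ and return
// df_builder_ as the function graph to compile and execute.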
FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op,
                                         const std::vector<AnfNodePtr> &weights, size_t arg_size) {
  auto nparam = top_g_->parameters().size();
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  df_builder_->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  df_builder_->debug_info()->set_name(ss.str());

  auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights);
  std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
  for (size_t i = 0; i < arg_size; ++i) {
    inputs.push_back(df_builder_->parameters()[i]);
  }
  auto out = df_builder_->NewCNode(inputs);
  df_builder_->set_output(out);
  resource_->manager()->AddFuncGraph(df);
  resource_->manager()->AddFuncGraph(df_builder_);
  return df_builder_;
}
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
                           .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""),
                                "Executor run function.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.");
                       }));
}  // namespace pynative
}  // namespace mindspore