You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

pynative_execute.cc 69 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/pynative/pynative_execute.h"
  17. #include <typeinfo>
  18. #include <map>
  19. #include <set>
  20. #include <memory>
  21. #include <unordered_set>
  22. #include <algorithm>
  23. #include "debug/trace.h"
  24. #include "pybind_api/ir/tensor_py.h"
  25. #include "ir/param_info.h"
  26. #include "ir/anf.h"
  27. #include "ir/tensor.h"
  28. #include "utils/any.h"
  29. #include "utils/utils.h"
  30. #include "utils/ms_context.h"
  31. #include "utils/context/context_extends.h"
  32. #include "utils/config_manager.h"
  33. #include "utils/convert_utils_py.h"
  34. #include "frontend/operator/ops.h"
  35. #include "frontend/operator/composite/composite.h"
  36. #include "frontend/operator/composite/do_signature.h"
  37. #include "pipeline/jit/parse/data_converter.h"
  38. #include "pipeline/jit/parse/parse_base.h"
  39. #include "pipeline/jit/parse/resolve.h"
  40. #include "pipeline/jit/static_analysis/prim.h"
  41. #include "backend/session/session_factory.h"
  42. #include "backend/optimizer/pass/const_input_to_attr_registry.h"
  43. #include "backend/optimizer/common/helper.h"
  44. #include "pipeline/jit/action.h"
  45. #include "pipeline/pynative/base.h"
  46. #include "pybind_api/api_register.h"
  47. #include "vm/transform.h"
  48. #include "frontend/optimizer/ad/grad.h"
  49. #include "pipeline/jit/resource.h"
  50. #include "pipeline/jit/pipeline.h"
  51. #include "pipeline/jit/pass.h"
  52. #ifdef ENABLE_GE
  53. #include "pipeline/pynative/pynative_execute_ge.h"
  54. #endif
  55. #include "debug/anf_ir_dump.h"
  56. using mindspore::tensor::TensorPy;
  57. const char SINGLE_OP_GRAPH[] = "single_op_graph";
  58. // primitive unable to infer value for constant input in PyNative mode
  59. const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
  60. "mixed_precision_cast"};
  61. namespace mindspore {
  62. namespace pynative {
  63. static std::shared_ptr<session::SessionBasic> session = nullptr;
  64. PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
  65. std::mutex PynativeExecutor::instance_lock_;
  66. ResourcePtr PynativeExecutor::resource_;
  67. int PynativeExecutor::graph_id_ = 0;
// Invoke `method` on `executor`, normalising error handling for the PyNative
// front end: on any failure the executor's cached graph state is dropped via
// Clean() before the exception is re-thrown toward the Python interpreter.
//
// NOTE(review): the catch-clause order is deliberate — the specific pybind11
// exception types are caught (and re-thrown as the same type) before the
// generic std::exception fallback, so Python receives TypeError/ValueError/
// IndexError rather than a plain RuntimeError.
// NOTE(review): `args` is taken as a forwarding pack but invoked as
// `(executor->*method)(args...)` without std::forward — presumably fine for
// the current callers; confirm before relying on move semantics here.
template <typename... Args>
void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
  try {
    (executor->*method)(args...);
  } catch (const py::error_already_set &ex) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
    // these info from screen, no need to open log file to find these info
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    PynativeExecutor::GetInstance()->Clean();
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    PynativeExecutor::GetInstance()->Clean();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    PynativeExecutor::GetInstance()->Clean();
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    PynativeExecutor::GetInstance()->Clean();
    // Unknown (non-std) exception: log its mangled type name for diagnosis.
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
}
  103. inline ValuePtr PyAttrValue(const py::object &obj) {
  104. ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  105. if (!converted_ret) {
  106. MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
  107. }
  108. return converted_ret;
  109. }
  110. static std::string GetId(const py::object &obj) {
  111. py::object to_process = obj;
  112. std::string prefix = "";
  113. if (py::isinstance<py::tuple>(to_process) || py::isinstance<py::list>(to_process)) {
  114. auto p_list = py::cast<py::tuple>(to_process);
  115. if (p_list.empty()) {
  116. return "empty";
  117. }
  118. prefix = py::isinstance<py::tuple>(to_process) ? "tuple:" : "list";
  119. std::string key = "";
  120. for (size_t i = 0; i < p_list.size(); ++i) {
  121. key += std::string(py::str(GetId(p_list[i]))) + ":";
  122. }
  123. return prefix + key;
  124. }
  125. if (py::isinstance<mindspore::Type>(to_process)) {
  126. auto type_ptr = py::cast<mindspore::TypePtr>(to_process);
  127. return "type" + type_ptr->ToString();
  128. }
  129. if (py::isinstance<py::str>(to_process)) {
  130. return "s" + std::string(py::str(to_process));
  131. }
  132. if (py::isinstance<py::int_>(to_process)) {
  133. return prefix + std::string(py::str(to_process));
  134. }
  135. if (py::isinstance<py::float_>(to_process)) {
  136. return prefix + std::string(py::str(to_process));
  137. }
  138. if (py::isinstance<tensor::Tensor>(to_process)) {
  139. auto tensor_ptr = py::cast<tensor::TensorPtr>(to_process);
  140. return prefix + tensor_ptr->id();
  141. }
  142. py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  143. return py::cast<std::string>(ret);
  144. }
  145. static std::string GetOpId(const OpExecInfoPtr &op_exec_info) {
  146. auto id = GetId(op_exec_info->py_primitive->GetPyObj());
  147. op_exec_info->prim_id = id;
  148. return id;
  149. }
  150. std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
  151. std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
  152. for (size_t i = 0; i < dtypes.size(); ++i) {
  153. auto it = type_indexes.find(dtypes[i]);
  154. if (it == type_indexes.end()) {
  155. (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
  156. } else {
  157. it->second.push_back(i);
  158. }
  159. }
  160. return type_indexes;
  161. }
// For every signature-dtype group that has at least two members, decide the
// TypeId all arguments of that group should be implicitly cast to.  The
// target is the tensor type with the highest priority in prim::type_map,
// then adjusted for mixed scalar kinds (bool/int/float) seen in the group.
std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
                                                const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
  std::map<SignatureEnumDType, TypeId> dst_type;
  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
    auto type = it->first;
    auto indexes = it->second;
    // A group with fewer than two members needs no unification, and the
    // "empty default value" kind carries no dtype constraint at all.
    if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < 2) {
      continue;
    }
    size_t priority = 0;
    TypeId max_type = TypeId::kTypeUnknown;
    bool has_float = false;  // a python float scalar appeared in the group
    bool has_int = false;    // a python int (non-bool) scalar appeared
    bool has_int8 = false;   // some tensor in the group is int8
    for (size_t index : indexes) {
      if (!has_float && py::isinstance<py::float_>(py_args[index])) {
        has_float = true;
      }
      // py::bool_ is excluded explicitly because python bool satisfies the
      // py::int_ check as well.
      if (!has_int && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
        has_int = true;
      }
      auto obj = py_args[index];
      if (py::isinstance<tensor::Tensor>(obj)) {
        auto arg = py::cast<tensor::TensorPtr>(obj);
        TypeId arg_type_id = arg->data_type();
        auto type_priority = prim::type_map.find(arg_type_id);
        if (type_priority == prim::type_map.end()) {
          continue;
        }
        if (arg_type_id == kNumberTypeInt8) {
          has_int8 = true;
        }
        // Keep the highest-priority tensor dtype seen so far.
        // NOTE(review): the comparison is strict (>) against an initial
        // priority of 0 — assumes type_map priorities are > 0; confirm.
        if (type_priority->second > priority) {
          max_type = type_priority->first;
          priority = type_priority->second;
        }
      }
    }
    // A bool tensor group is widened if scalar ints/floats are mixed in;
    // float deliberately wins over int when both appear.
    if (max_type == TypeId::kNumberTypeBool) {
      if (has_int) {
        max_type = TypeId::kNumberTypeInt32;
      }
      if (has_float) {
        max_type = TypeId::kNumberTypeFloat32;
      }
    }
    // Mixing a float scalar into a non-floating tensor group promotes the
    // whole group to float32.
    if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
        max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_float) {
      max_type = TypeId::kNumberTypeFloat32;
    }
    // uint8 combined with int8 cannot be represented by either; use int16.
    if (max_type == TypeId::kNumberTypeUInt8 && has_int8) {
      max_type = TypeId::kNumberTypeInt16;
    }
    (void)dst_type.insert(std::make_pair(type, max_type));
  }
  return dst_type;
}
  219. std::string TypeIdToMsTypeStr(const TypeId &type_id) {
  220. auto type_name = type_name_map.find(type_id);
  221. if (type_name == type_name_map.end()) {
  222. MS_LOG(EXCEPTION) << "For implicit type conversion, not support convert to the type: " << TypeIdToType(type_id);
  223. }
  224. return type_name->second;
  225. }
  226. py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
  227. py::tuple args(3);
  228. std::string module_name = "mindspore.ops.functional";
  229. std::string op_name = "cast";
  230. args[0] = parse::python_adapter::GetPyFn(module_name, op_name);
  231. args[1] = "Cast";
  232. std::string dst_type_str = TypeIdToMsTypeStr(type_id);
  233. module_name = "mindspore.common.dtype";
  234. py::object dst_type = parse::python_adapter::GetPyFn(module_name, dst_type_str);
  235. py::tuple inputs(2);
  236. inputs[0] = arg;
  237. inputs[1] = dst_type;
  238. args[2] = inputs;
  239. return RunOp(args)[0];
  240. }
  241. py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
  242. auto tensor = py::cast<tensor::TensorPtr>(obj);
  243. auto cast_type = tensor->cast_dtype();
  244. py::object cast_output = obj;
  245. if (cast_type != nullptr) {
  246. auto source_element = tensor->Dtype();
  247. if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
  248. MS_LOG(DEBUG) << "cast to " << cast_type->ToString();
  249. cast_output = DoAutoCast(obj, cast_type->type_id());
  250. *is_cast = true;
  251. }
  252. }
  253. return cast_output;
  254. }
  255. py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
  256. auto tuple_size = static_cast<int>(tuple.size());
  257. py::tuple result(tuple_size);
  258. for (int i = 0; i < tuple_size; i++) {
  259. if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
  260. MS_LOG(DEBUG) << "call cast for item " << i;
  261. result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
  262. } else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
  263. result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
  264. } else {
  265. result[i] = tuple[i];
  266. }
  267. }
  268. return result;
  269. }
  270. bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
  271. MS_EXCEPTION_IF_NULL(dtypes);
  272. auto signature = prim->signatures();
  273. bool has_sig_dtype = false;
  274. (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
  275. [&has_sig_dtype](const Signature &sig) {
  276. auto dtype = sig.dtype;
  277. if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
  278. has_sig_dtype = true;
  279. }
  280. return dtype;
  281. });
  282. return has_sig_dtype;
  283. }
// Apply implicit dtype casts to the op inputs according to the primitive's
// signature and the per-group target types computed by GetDstType.
// Inputs bound to a "write" (ref) signature must already be Parameters of
// the correct dtype — those raise instead of being cast.
// NOTE(review): `signature[i]` is indexed by the input position whenever
// `signature` is non-empty; this assumes signatures always cover every
// input — confirm there is no op with fewer signature entries than inputs.
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
                     const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) {
  const auto &signature = prim->signatures();
  auto &out_args = op_exec_info->op_inputs;
  bool has_dtype_sig = (dtypes.size() > 0);
  for (size_t i = 0; i < out_args.size(); ++i) {
    MS_LOG(DEBUG) << "check inputs " << i;
    auto obj = out_args[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (signature.size() > 0) {
      sig = signature[i].rw;
    }
    bool is_parameter = false;
    TypeId arg_type_id = kTypeUnknown;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
      if (arg->is_parameter()) {
        is_parameter = true;
        MS_LOG(DEBUG) << "parameter is read " << i;
      }
      arg_type_id = arg->data_type();
    }
    // No need to implicit cast if no dtype.
    if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
      continue;
    }
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end() || it->second == kTypeUnknown) {
      continue;
    }
    // implicit cast
    bool is_same_type = false;
    bool is_sig_write = (sig == SignatureEnumRW::kRWWrite);
    // arg_type_id != 0 means the input was a MetaTensor with a known dtype
    // (kTypeUnknown is 0); dtypes outside prim::type_map are treated as
    // already matching so they are never cast.
    if (arg_type_id != 0) {
      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
    }
    if (is_sig_write) {
      // Ref (write) inputs must be Parameters and must already carry the
      // target dtype — implicit conversion of a ref would silently change
      // the parameter's storage type.
      if (!is_parameter) {
        prim::RaiseExceptionForCheckParameter(prim->name(), i, "not");
      }
      if (arg_type_id != 0) {
        if (!is_same_type) {
          prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
                                                 TypeIdToMsTypeStr(it->second));
        }
      }
    }
    if (is_same_type) {
      continue;
    }
    // Only tensors and python int/float scalars support implicit casts.
    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
                              << "th input is a not support implicit conversion type: "
                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
                              << py::cast<py::str>(obj) << ".";
    }
    py::object cast_output = DoAutoCast(out_args[i], it->second);
    out_args[i] = cast_output;
  }
}
  344. void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
  345. const abstract::AbstractBasePtrList &args_spec_list) {
  346. MS_LOG(DEBUG) << "prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
  347. prim->BeginRecordAddAttr();
  348. AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  349. prim->EndRecordAddAttr();
  350. op_exec_info->abstract = infer_res;
  351. MS_LOG(DEBUG) << "prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
  352. }
  353. OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
  354. if (args.size() != PY_ARGS_NUM) {
  355. MS_LOG(ERROR) << "Three args are needed by RunOp";
  356. return nullptr;
  357. }
  358. auto op_exec_info = std::make_shared<OpExecInfo>();
  359. MS_EXCEPTION_IF_NULL(op_exec_info);
  360. op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
  361. auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  362. if (!prim->HasPyObj()) {
  363. MS_LOG(EXCEPTION) << "pyobj is empty";
  364. }
  365. op_exec_info->py_primitive = prim;
  366. op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  367. op_exec_info->op_inputs = args[PY_INPUTS];
  368. return op_exec_info;
  369. }
  370. std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
  371. const std::vector<tensor::TensorPtr> &input_tensors) {
  372. MS_EXCEPTION_IF_NULL(op_exec_info);
  373. std::string graph_info;
  374. // get input tensor info
  375. for (const auto &tensor : input_tensors) {
  376. MS_EXCEPTION_IF_NULL(tensor);
  377. auto tensor_shape = tensor->shape();
  378. (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
  379. [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
  380. (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
  381. if (tensor->device_address() != nullptr) {
  382. (void)graph_info.append(
  383. std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id()) + "_");
  384. (void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format() +
  385. "_");
  386. }
  387. }
  388. // get prim and abstract info
  389. (void)graph_info.append(op_exec_info->prim_id + "_");
  390. // get attr info
  391. const auto &op_prim = op_exec_info->py_primitive;
  392. MS_EXCEPTION_IF_NULL(op_prim);
  393. const auto &attr_map = op_prim->evaluate_added_attrs();
  394. (void)std::for_each(attr_map.begin(), attr_map.end(),
  395. [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
  396. return graph_info;
  397. }
  398. py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  399. MS_LOG(INFO) << "RunOpInVM start";
  400. MS_EXCEPTION_IF_NULL(status);
  401. MS_EXCEPTION_IF_NULL(op_exec_info);
  402. MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  403. auto &op_inputs = op_exec_info->op_inputs;
  404. if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf") {
  405. py::tuple result(op_inputs.size());
  406. for (size_t i = 0; i < op_inputs.size(); i++) {
  407. py::object input = op_inputs[i];
  408. auto tensor = py::cast<tensor::TensorPtr>(input);
  409. auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
  410. new_tensor->set_device_address(tensor->device_address());
  411. new_tensor->set_sync_status(tensor->sync_status());
  412. result[i] = new_tensor;
  413. }
  414. *status = PYNATIVE_SUCCESS;
  415. MS_LOG(INFO) << "RunOpInVM end";
  416. return std::move(result);
  417. }
  418. auto primitive = op_exec_info->py_primitive;
  419. MS_EXCEPTION_IF_NULL(primitive);
  420. auto result = primitive->RunPyComputeFunction(op_inputs);
  421. if (py::isinstance<py::none>(result)) {
  422. MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
  423. *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
  424. py::tuple err_ret(0);
  425. return std::move(err_ret);
  426. }
  427. // execute op
  428. py::tuple tuple_result = py::make_tuple(result);
  429. *status = PYNATIVE_SUCCESS;
  430. MS_LOG(INFO) << "RunOpInVM end";
  431. return std::move(tuple_result);
  432. }
  433. bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
  434. const std::unordered_set<size_t> &input_attrs) {
  435. MS_EXCEPTION_IF_NULL(op_prim);
  436. auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  437. if (input_names_value == nullptr) {
  438. return false;
  439. }
  440. auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  441. if (input_index >= input_names_vec.size()) {
  442. MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!";
  443. }
  444. if (input_attrs.find(input_index) != input_attrs.end()) {
  445. ValuePtr value = parse::data_converter::PyDataToValue(input_object);
  446. MS_EXCEPTION_IF_NULL(value);
  447. auto input_name = input_names_vec[input_index];
  448. op_prim->AddAttr(input_name, value);
  449. return true;
  450. }
  451. return false;
  452. }
  453. void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
  454. std::vector<tensor::TensorPtr> *input_tensors) {
  455. MS_EXCEPTION_IF_NULL(op_prim);
  456. MS_EXCEPTION_IF_NULL(input_tensors);
  457. for (const auto &input_object : tuple_inputs) {
  458. if (!py::isinstance<tensor::Tensor>(input_object)) {
  459. MS_LOG(EXCEPTION) << "The input object is not a tensor!";
  460. }
  461. auto tensor = py::cast<tensor::TensorPtr>(input_object);
  462. MS_EXCEPTION_IF_NULL(tensor);
  463. input_tensors->push_back(tensor);
  464. }
  465. op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
  466. }
  467. void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  468. MS_EXCEPTION_IF_NULL(input_tensors);
  469. ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  470. MS_EXCEPTION_IF_NULL(input_value);
  471. if (!input_value->isa<ValueTuple>()) {
  472. MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  473. }
  474. auto value_tuple = input_value->cast<ValueTuplePtr>();
  475. MS_EXCEPTION_IF_NULL(value_tuple);
  476. tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  477. MS_EXCEPTION_IF_NULL(tensor_ptr);
  478. input_tensors->push_back(tensor_ptr);
  479. }
  480. void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
  481. std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  482. MS_EXCEPTION_IF_NULL(op_prim);
  483. MS_EXCEPTION_IF_NULL(input_tensors);
  484. MS_EXCEPTION_IF_NULL(tensor_mask);
  485. if (!py::isinstance<py::tuple>(input_object)) {
  486. MS_LOG(EXCEPTION) << "The input should be a tuple!";
  487. }
  488. auto tuple_inputs = py::cast<py::tuple>(input_object);
  489. if (tuple_inputs.size() == 0) {
  490. MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  491. }
  492. auto inputs = py::cast<py::tuple>(input_object);
  493. if (py::isinstance<tensor::Tensor>(inputs[0])) {
  494. PlantTensorTupleToVector(inputs, op_prim, input_tensors);
  495. } else {
  496. ConvertValueTupleToTensor(input_object, input_tensors);
  497. *tensor_mask = kValueNodeTensorMask;
  498. }
  499. }
  500. void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
  501. std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
  502. MS_EXCEPTION_IF_NULL(op_prim);
  503. MS_EXCEPTION_IF_NULL(input_tensors);
  504. MS_EXCEPTION_IF_NULL(tensor_mask);
  505. tensor::TensorPtr tensor_ptr = nullptr;
  506. if (py::isinstance<tensor::Tensor>(input_object)) {
  507. tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  508. } else if (py::isinstance<py::float_>(input_object)) {
  509. double input_value = py::cast<py::float_>(input_object);
  510. tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
  511. *tensor_mask = kValueNodeTensorMask;
  512. } else if (py::isinstance<py::int_>(input_object)) {
  513. tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
  514. *tensor_mask = kValueNodeTensorMask;
  515. } else if (py::isinstance<py::array>(input_object)) {
  516. tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
  517. } else if (py::isinstance<py::list>(input_object)) {
  518. auto list_inputs = py::cast<py::list>(input_object);
  519. py::tuple tuple_inputs(list_inputs.size());
  520. for (size_t i = 0; i < tuple_inputs.size(); ++i) {
  521. tuple_inputs[i] = list_inputs[i];
  522. }
  523. ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
  524. return;
  525. } else if (py::isinstance<py::tuple>(input_object)) {
  526. ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
  527. return;
  528. } else if (py::isinstance<py::none>(input_object)) {
  529. return;
  530. } else {
  531. MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  532. }
  533. MS_EXCEPTION_IF_NULL(tensor_ptr);
  534. input_tensors->push_back(tensor_ptr);
  535. }
  536. void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *tensors_mask,
  537. std::vector<tensor::TensorPtr> *input_tensors) {
  538. MS_EXCEPTION_IF_NULL(op_run_info);
  539. MS_EXCEPTION_IF_NULL(tensors_mask);
  540. MS_EXCEPTION_IF_NULL(input_tensors);
  541. PrimitivePtr op_prim = op_run_info->py_primitive;
  542. MS_EXCEPTION_IF_NULL(op_prim);
  543. opt::ConstInputToAttrInfoRegister reg;
  544. bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  545. if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
  546. reg_exist = false;
  547. }
  548. op_prim->BeginRecordAddAttr();
  549. size_t input_num = op_run_info->op_inputs.size();
  550. for (size_t index = 0; index < input_num; ++index) {
  551. // convert const input to attr
  552. if (reg_exist &&
  553. RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
  554. continue;
  555. }
  556. // convert const and tuple input to tensor
  557. int tensor_mask = static_cast<int>(op_run_info->inputs_mask[index]);
  558. ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
  559. // mark tensors, data : 0, weight : 1, valuenode: 2
  560. std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
  561. tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  562. }
  563. op_prim->EndRecordAddAttr();
  564. }
  565. void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
  566. MS_EXCEPTION_IF_NULL(input_tensors);
  567. if (input_tensors->size() != tensors_mask.size()) {
  568. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
  569. << tensors_mask.size();
  570. }
  571. std::vector<tensor::TensorPtr> new_input_tensors;
  572. for (size_t index = 0; index < tensors_mask.size(); ++index) {
  573. if (tensors_mask[index] != kValueNodeTensorMask) {
  574. new_input_tensors.push_back(input_tensors->at(index));
  575. }
  576. }
  577. *input_tensors = new_input_tensors;
  578. }
  579. BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  580. if (utils::isa<VectorRef>(base_ref)) {
  581. auto ref_list = utils::cast<VectorRef>(base_ref);
  582. py::tuple output_tensors(ref_list.size());
  583. for (size_t i = 0; i < ref_list.size(); ++i) {
  584. auto output = TransformBaseRefListToTuple(ref_list[i]);
  585. if (utils::isa<tensor::TensorPtr>(output)) {
  586. auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
  587. MS_EXCEPTION_IF_NULL(tensor_ptr);
  588. output_tensors[i] = tensor_ptr;
  589. } else if (utils::isa<PyObjectRef>(output)) {
  590. py::object obj = utils::cast<PyObjectRef>(output).object_;
  591. py::tuple tensor_tuple = py::cast<py::tuple>(obj);
  592. output_tensors[i] = tensor_tuple;
  593. } else {
  594. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  595. }
  596. }
  597. return std::make_shared<PyObjectRef>(output_tensors);
  598. } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
  599. return base_ref;
  600. } else {
  601. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  602. }
  603. }
  604. py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  605. MS_EXCEPTION_IF_NULL(op_exec_info);
  606. MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
  607. auto ms_context = MsContext::GetInstance();
  608. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
  609. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  610. if (device_target != kAscendDevice && device_target != kGPUDevice) {
  611. MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
  612. }
  613. if (session == nullptr) {
  614. session = session::SessionFactory::Get().Create(device_target);
  615. MS_EXCEPTION_IF_NULL(session);
  616. session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  617. }
  618. std::vector<tensor::TensorPtr> input_tensors;
  619. std::vector<int> tensors_mask;
  620. ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  621. // get graph info for checking it whether existing in the cache
  622. std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
  623. session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, op_exec_info->abstract,
  624. op_exec_info->value};
  625. session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask);
  626. EraseValueNodeTensor(tensors_mask, &input_tensors);
  627. VectorRef outputs;
  628. session->RunOp(&op_run_info, graph_info, input_tensors, &outputs);
  629. auto result = BaseRefToPyData(outputs);
  630. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  631. *status = PYNATIVE_SUCCESS;
  632. MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
  633. return result;
  634. }
  635. py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
  636. PynativeStatusCode *const status) {
  637. MS_EXCEPTION_IF_NULL(status);
  638. py::object result;
  639. switch (backend_policy) {
  640. case kMsBackendVmOnly: {
  641. // use vm only
  642. MS_LOG(INFO) << "RunOp use VM only backend";
  643. result = RunOpInVM(op_exec_info, status);
  644. break;
  645. }
  646. case kMsBackendGePrior: {
  647. #ifdef ENABLE_GE
  648. // use GE first, use vm when GE fails
  649. MS_LOG(INFO) << "RunOp use GE first backend";
  650. result = RunOpInGE(op_exec_info, status);
  651. if (*status != PYNATIVE_SUCCESS) {
  652. result = RunOpInVM(op_exec_info, status);
  653. }
  654. #endif
  655. break;
  656. }
  657. case kMsBackendMsPrior: {
  658. // use Ms fisrt,use others when ms failed
  659. MS_LOG(INFO) << "RunOp use Ms first backend";
  660. result = RunOpInMs(op_exec_info, status);
  661. if (*status != PYNATIVE_SUCCESS) {
  662. MS_LOG(ERROR) << "RunOp use Ms backend failed!!!";
  663. }
  664. break;
  665. }
  666. default:
  667. MS_LOG(ERROR) << "No backend configured for run op";
  668. }
  669. return result;
  670. }
  671. ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
  672. auto id = GetOpId(op_exec_info);
  673. int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
  674. auto op = std::to_string(graph_id) + id;
  675. op.append(std::to_string(op_id_map_[id]));
  676. auto iter = op_forward_map_.find(op);
  677. if (iter != op_forward_map_.end()) {
  678. ++op_id_map_[id];
  679. MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second;
  680. return iter->second;
  681. }
  682. if (!first_grad_step_) {
  683. ++op_id_map_[id];
  684. }
  685. return nullptr;
  686. }
// Build the forward CNode for one op call: apply mixed-precision and implicit
// signature casts to the inputs, collect each input's parameter mask and
// abstract, and (when grad recording is on) create the node in the current
// graph. Returns nullptr when no graph node is created.
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
                                       abstract::AbstractBasePtrList *args_spec_list) {
  MS_EXCEPTION_IF_NULL(op_masks);
  MS_EXCEPTION_IF_NULL(args_spec_list);
  CNodePtr cnode = nullptr;
  std::vector<AnfNodePtr> inputs;
  auto prim = op_exec_info->py_primitive;
  const auto &signature = prim->signatures();
  inputs.push_back(NewValueNode(prim));
  size_t size = op_exec_info->op_inputs.size();
  auto sig_size = signature.size();
  // ignore signature for cast op
  if (sig_size > 0 && sig_size != size) {
    MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
                             << "inputs size " << sig_size;
  }
  bool is_cast_op = (op_exec_info->op_name == "Cast");
  if (!is_cast_op) {
    for (size_t i = 0; i < size; i++) {
      auto obj = op_exec_info->op_inputs[i];
      auto sig = SignatureEnumRW::kRWDefault;
      if (sig_size > 0) {
        sig = signature[i].rw;
      }
      MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
                    << std::string(py::repr(obj));
      // mix precision for non param
      bool is_cast = false;
      py::object cast_output;
      if (py::isinstance<tensor::MetaTensor>(obj)) {
        auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
        if (meta_tensor && meta_tensor->is_parameter()) {
          // Parameters are only cast when the signature marks them read-only.
          if (sig != SignatureEnumRW::kRWRead) {
            continue;
          }
        }
        // redundant cast call if the tensor is a const Tensor.
        cast_output = DoParamMixPrecisionCast(&is_cast, obj);
      } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
        // mix precision for tuple inputs
        cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
      }
      // Replace the original input with its cast result when a cast happened.
      if (is_cast) {
        op_exec_info->op_inputs[i] = cast_output;
      }
    }
    std::vector<SignatureEnumDType> dtypes;
    bool has_dtype_sig = GetSignatureType(prim, &dtypes);
    std::map<SignatureEnumDType, TypeId> dst_types;
    if (has_dtype_sig) {
      // fetch info for implicit cast
      auto type_indexes = GetTypeIndex(dtypes);
      dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
    }
    MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
    DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
  }
  MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name;
  for (size_t i = 0; i < size; i++) {
    const auto &obj = op_exec_info->op_inputs[i];
    bool op_mask = false;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor) {
        op_mask = meta_tensor->is_parameter();
      }
    }
    (*op_masks).push_back(op_mask);
    MS_LOG(DEBUG) << "gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ "
                  << grad_flag_;
    AnfNodePtr node = nullptr;
    abstract::AbstractBasePtr abs = nullptr;
    auto id = GetId(obj);
    // Reuse a cached abstract for this object when available.
    if (node_abs_map_.find(id) != node_abs_map_.end()) {
      abs = node_abs_map_[id];
    }
    if (!graph_info_map_.empty()) {
      node = GetInput(obj, op_mask);
    }
    // The node's own abstract (if already set) overrides the cached one.
    if (node != nullptr && node->abstract() != nullptr) {
      abs = node->abstract();
    }
    auto const_input_index = prim->get_const_input_indexes();
    bool have_const_input = !const_input_index.empty();
    bool is_const_prim = prim->is_const_prim();
    MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                  << prim->is_const_prim();
    bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i);
    // Const prims/inputs must re-derive the abstract from the concrete value
    // (without broadening) on every call; otherwise broaden so the cached
    // abstract generalizes across values.
    if (abs == nullptr || is_const_prim || is_const_input) {
      MS_LOG(DEBUG) << "MakeCnode get node no in map" << id;
      ValuePtr input_value = PyAttrValue(obj);
      abs = input_value->ToAbstract();
      if (!is_const_prim && !is_const_input) {
        auto config = abstract::AbstractBase::kBroadenTensorOnly;
        abs = abs->Broaden(config);
        MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config;
      }
      node_abs_map_[id] = abs;
    }
    (*args_spec_list).push_back(abs);
    inputs.push_back(node);
  }
  MS_LOG(DEBUG) << "MakeCnode args end";
  if (grad_flag_) {
    if (curr_g_ != nullptr) {
      cnode = curr_g_->NewCNode(inputs);
      MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4);
    }
  }
  return cnode;
}
  798. void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out_real,
  799. const AnfNodePtr &cnode) {
  800. if (!grad_flag_ || graph_info_map_.empty()) {
  801. MS_LOG(DEBUG) << "no graph cnode";
  802. return;
  803. }
  804. std::string obj_id = GetId(out_real);
  805. MS_EXCEPTION_IF_NULL(cnode);
  806. MS_LOG(DEBUG) << "MakeCnode set obj node id " << cnode->DebugString(4) << "id " << obj_id;
  807. if (py::isinstance<py::tuple>(out_real)) {
  808. auto value = py::cast<py::tuple>(out_real);
  809. if (value.size() > 1) {
  810. for (int i = 0; i < static_cast<int>(value.size()); i++) {
  811. auto value_id = GetId(value[i]);
  812. MS_LOG(DEBUG) << "MakeCnode set node id " << value_id;
  813. set_obj_node_map(curr_g_, value_id, cnode, i);
  814. }
  815. }
  816. }
  817. set_obj_node_map(curr_g_, obj_id, cnode);
  818. set_pyobj(curr_g_, obj_id);
  819. }
  820. void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) {
  821. if (t_map == nullptr) {
  822. return;
  823. }
  824. for (size_t i = 0; i < tuple->size(); i++) {
  825. ValuePtr tuple_i = (*tuple)[i];
  826. if (tuple_i->isa<tensor::Tensor>()) {
  827. auto t = tuple_i->cast<tensor::TensorPtr>();
  828. (*t_map)[t->id()] = t;
  829. } else if (tuple_i->isa<ValueTuple>()) {
  830. GenTupleMap(tuple_i->cast<ValueTuplePtr>(), t_map);
  831. }
  832. }
  833. MS_LOG(DEBUG) << "End GenTupleMap" << tuple->ToString();
  834. }
  835. ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple) {
  836. std::vector<ValuePtr> value_list;
  837. for (size_t i = 0; i < tuple->size(); i++) {
  838. ValuePtr tuple_i = (*tuple)[i];
  839. if (tuple_i->isa<tensor::Tensor>()) {
  840. auto t = tuple_i->cast<tensor::TensorPtr>();
  841. auto new_tensor = std::make_shared<tensor::Tensor>(*t);
  842. new_tensor->set_device_address(nullptr);
  843. value_list.push_back(new_tensor);
  844. } else if (tuple_i->isa<ValueTuple>()) {
  845. value_list.push_back(CleanTupleAddr(tuple_i->cast<ValueTuplePtr>()));
  846. } else {
  847. MS_LOG(DEBUG) << "in value" << tuple_i->ToString();
  848. value_list.push_back(tuple_i);
  849. }
  850. }
  851. MS_LOG(DEBUG) << "End CleanTupleAddr";
  852. return std::make_shared<ValueTuple>(value_list);
  853. }
// Cache the forward output of an op under key |id| for reuse by the backward
// pass. On a cache hit, only the stored tensor's device address is refreshed
// (and |t_map| filled); tuple values cached for the first time have their
// device addresses cleared via CleanTupleAddr.
void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value,
                                          std::map<std::string, tensor::TensorPtr> *t_map) {
  if (op_forward_map_.find(id) != op_forward_map_.end()) {
    if (op_forward_map_[id]->isa<ValueTuple>()) {
      // for one op have multi outputs but save only one tensor
      if (value->isa<tensor::Tensor>()) {
        auto tuple = op_forward_map_[id]->cast<ValueTuplePtr>();
        auto value_t = value->cast<tensor::TensorPtr>();
        for (size_t i = 0; i < tuple->size(); i++) {
          if ((*tuple)[i]->isa<tensor::Tensor>()) {
            auto tuple_t = (*tuple)[i]->cast<tensor::TensorPtr>();
            // Match the cached element by tensor id and refresh its device address.
            if (value_t->id() == tuple_t->id()) {
              tuple_t->set_device_address(value_t->device_address());
              MS_LOG(DEBUG) << "After Saveop " << tuple_t->ToString();
              break;
            }
          }
        }
      }
    }
    if (value->isa<ValueTuple>() && t_map != nullptr) {
      GenTupleMap(op_forward_map_[id]->cast<ValueTuplePtr>(), t_map);
    }
    MS_LOG(DEBUG) << "Save op forward value: "
                  << "(" << id << "), " << op_forward_map_[id]->ToString();
    return;
  }
  // First save for this key.
  if (value->isa<ValueTuple>() && t_map == nullptr) {
    // make cnode gen all tuple node and set device_address be null
    op_forward_map_[id] = CleanTupleAddr(value->cast<ValueTuplePtr>());
  } else {
    op_forward_map_[id] = value;
  }
  MS_LOG(DEBUG) << "Save op forward value: "
                << "(" << id << "), " << value->ToString();
}
// After an op executes, attach its forward output to the cnode (for reuse in
// the backward pass) and remember the object-id -> forward-op-id mapping for
// every output. Skipped when grad is off or a cached value was already used.
void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
  if (!grad_flag_ || op_exec_info->value != nullptr) {
    return;
  }
  py::object out_real = out;
  // Single-output ops: unwrap the 1-tuple.
  if (out.size() == 1) {
    out_real = out[0];
  }
  auto value = PyAttrValue(out_real);
  if (cnode != nullptr) {
    size_t size = op_exec_info->op_inputs.size();
    for (size_t i = 0; i < size; i++) {
      auto obj = op_exec_info->op_inputs[i];
      auto obj_id = GetId(obj);
      // Inputs produced by earlier ops carry that op's forward id; all others
      // get a null placeholder so input positions stay aligned.
      if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) {
        cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]);
      } else {
        cnode->add_input_value(nullptr, "");
      }
    }
    // Forward id is "<graph id><op id><per-op call counter>".
    std::string id = GetOpId(op_exec_info);
    int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
    auto op_id = std::to_string(graph_id) + id;
    op_id.append(std::to_string(op_id_map_[id]));
    cnode->set_forward(value, op_id);
    ++op_id_map_[id];
    auto out_id = GetId(out_real);
    // Tuple outputs: map every element to the same forward id and cache the value.
    if (py::isinstance<py::tuple>(out_real)) {
      auto tuple_item = py::cast<py::tuple>(out_real);
      for (size_t i = 0; i < tuple_item.size(); i++) {
        auto tuple_item_id = GetId(tuple_item[i]);
        obj_to_forward_id_[tuple_item_id] = op_id;
      }
      SaveOpForwardValue(op_id, value, nullptr);
    }
    obj_to_forward_id_[out_id] = op_id;
  }
}
// Look up the graph node that produced |obj|. When the object is an element of
// a tuple output, chain TupleGetItem node(s) to extract it, propagating the
// saved forward value and tuple abstract onto the new node(s).
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
  auto id = GetId(obj);
  auto &out = graph_info_map_[curr_g_].obj_node_map[id];
  // An index list of exactly {-1} means obj is the node's whole output.
  if (out.second.size() == 1 && out.second[0] == -1) {
    return out.first;
  }
  CNodePtr node = out.first->cast<CNodePtr>();
  MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString();
  auto abs = node->abstract();
  ValuePtr out_obj = nullptr;
  // Prefer the forward value recorded on the node; fall back to converting obj.
  if (node->forward().first != nullptr) {
    out_obj = node->forward().first;
  } else {
    out_obj = PyAttrValue(obj);
  }
  // Emit one TupleGetItem per index for nested accesses like output[0][1].
  for (auto &idx : out.second) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
    node = curr_g_->NewCNode(tuple_get_item_inputs);
    if (out_obj->isa<ValueTuple>()) {
      node->add_input_value(out_obj, "");
      node->add_input_value(MakeValue(idx), "");
      // Narrow the forward value to the selected element for the next level.
      out_obj = (*out_obj->cast<ValueTuplePtr>())[idx];
      node->set_forward(out_obj, "");
    }
    // NOTE(review): |abs| is not narrowed per level, so for multi-level index
    // paths each iteration indexes the top-level tuple abstract — verify this
    // is intended for nested tuple outputs.
    if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
      auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
      MS_LOG(DEBUG) << "set tuple getitem abs" << prim_abs->ToString();
      node->set_abstract(prim_abs);
    }
  }
  if (node->abstract() != nullptr) {
    node_abs_map_[id] = node->abstract();
  }
  MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
  return node;
}
  964. AnfNodePtr PynativeExecutor::GetParamNode(const py::object &obj) {
  965. auto id = GetId(obj);
  966. auto &param = graph_info_map_[curr_g_].param_map[id];
  967. if (param.second.size() == 1 && param.second[0] == -1) {
  968. return param.first;
  969. }
  970. auto para_node = param.first;
  971. for (auto &idx : param.second) {
  972. std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, NewValueNode(idx)};
  973. para_node = curr_g_->NewCNode(tuple_get_item_inputs);
  974. }
  975. return para_node;
  976. }
  977. std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
  978. auto cell_id = GetId(cell);
  979. for (size_t i = 0; i < args.size(); i++) {
  980. std::string arg_id = GetId(args[i]);
  981. if (node_abs_map_.find(arg_id) != node_abs_map_.end()) {
  982. cell_id += node_abs_map_[arg_id]->ToString();
  983. } else {
  984. auto abs = PyAttrValue(args[i])->ToAbstract();
  985. auto config = abstract::AbstractBase::kBroadenTensorOnly;
  986. abs = abs->Broaden(config);
  987. cell_id += abs->ToString();
  988. node_abs_map_[arg_id] = abs;
  989. }
  990. }
  991. return cell_id;
  992. }
// Execute one op with the configured backend. Chooses the backend policy from
// the ms context (vm / ms-prior / GE-only under ENABLE_GE), forcing VM for ops
// listed in vm_operators. Returns an empty tuple on failure.
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  mindspore::parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // The TSD must be open before any device interaction.
  if (!context::IsTsdOpened(ms_context)) {
    if (!context::OpenTsd(ms_context)) {
      MS_LOG(EXCEPTION) << "Open tsd failed";
    }
  }
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  context::PynativeInitGe(ms_context);
  backend_policy = kMsBackendGeOnly;
#endif
  // Ops in vm_operators are only runnable on the VM backend.
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // returns a null py::tuple on error
  py::tuple err_ret(0);
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
    return err_ret;
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}
// Python-facing op execution: build the forward cnode, resolve the op's
// abstract (from the per-primitive cache or by python infer), execute the op
// with the chosen backend, and record outputs for grad when recording is on.
py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
  MS_LOG(DEBUG) << "RunOp start " << args.size();
  OpExecInfoPtr op_exec_info = nullptr;
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  auto name = py::cast<std::string>(args[PY_NAME]);
  abstract::AbstractBasePtrList args_spec_list;
  std::vector<bool> op_masks;
  op_exec_info = GenerateOpExecInfo(args);
  // MixedPrecisionCast bypasses cnode construction and runs directly.
  if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
    return RunOpInner(op_exec_info);
  }
  auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
  bool is_find = false;
  // Try the abstract cache keyed by primitive id + input abstracts.
  if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
    auto abs_list = prim_abs_list_[prim->id()];
    MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
    if (abs_list.find(args_spec_list) != abs_list.end()) {
      MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name;
      op_exec_info->abstract = abs_list[args_spec_list].abs;
      prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
      is_find = true;
    }
  }
  if (op_exec_info->abstract == nullptr) {
    // use python infer method
    if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
      PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
    }
  }
  if (cnode != nullptr) {
    cnode->set_abstract(op_exec_info->abstract);
    MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
  }
  op_exec_info->inputs_mask = op_masks;
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->abstract != nullptr) {
    MS_LOG(DEBUG) << "run op infer" << name << op_exec_info->abstract->ToString();
    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
    // If infer already produced a constant value, skip backend execution.
    if (!output["value"].is_none()) {
      py::tuple value_ret(1);
      value_ret[0] = output["value"];
      return value_ret;
    }
    if (op_exec_info->py_primitive->is_const_prim()) {
      py::tuple value_ret(1);
      value_ret[0] = "";
      return value_ret;
    }
  }
  if (!is_find) {
    // const_value need infer every step
    auto &out = prim_abs_list_[prim->id()];
    out[args_spec_list].abs = op_exec_info->abstract;
    out[args_spec_list].attrs = prim->evaluate_added_attrs();
    MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
  }
  // Under grad, try to reuse a cached forward value; otherwise still advance
  // the op id sequence so ids stay consistent between runs.
  if (PynativeExecutor::GetInstance()->grad_flag()) {
    op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info);
  } else {
    (void)GetOpId(op_exec_info);
  }
  auto result = RunOpInner(op_exec_info);
  py::object out_real = result;
  if (result.size() == 1) {
    MS_LOG(DEBUG) << "MakeCnode out size is one.";
    out_real = result[0];
  }
  std::string obj_id = GetId(out_real);
  node_abs_map_[obj_id] = op_exec_info->abstract;
  PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, out_real, cnode);
  if (cnode != nullptr) {
    PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast<CNodePtr>(), result);
  }
  return result;
}
  1105. py::tuple RunOp(const py::args &args) {
  1106. try {
  1107. return PynativeExecutor::GetInstance()->RunOpInner(args);
  1108. } catch (const py::error_already_set &ex) {
  1109. PynativeExecutor::GetInstance()->Clean();
  1110. // re-throw this exception to Python interpreter to handle it
  1111. throw(py::error_already_set(ex));
  1112. } catch (const py::type_error &ex) {
  1113. PynativeExecutor::GetInstance()->Clean();
  1114. throw py::type_error(ex);
  1115. } catch (const py::value_error &ex) {
  1116. PynativeExecutor::GetInstance()->Clean();
  1117. throw py::value_error(ex);
  1118. } catch (const py::index_error &ex) {
  1119. PynativeExecutor::GetInstance()->Clean();
  1120. throw py::index_error(ex);
  1121. } catch (const std::exception &ex) {
  1122. PynativeExecutor::GetInstance()->Clean();
  1123. // re-throw this exception to Python interpreter to handle it
  1124. throw(std::runtime_error(ex.what()));
  1125. } catch (...) {
  1126. PynativeExecutor::GetInstance()->Clean();
  1127. std::string exName(abi::__cxa_current_exception_type()->name());
  1128. MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  1129. }
  1130. }
// Drop the cached backend session so the next op run recreates it.
void ClearPyNativeSession() { session = nullptr; }
// Release all cached graphs and resources held by the executor.
PynativeExecutor::~PynativeExecutor() { ClearRes(); }
  1133. PynativeExecutor::PynativeExecutor() {
  1134. grad_flag_ = false;
  1135. first_grad_step_ = false;
  1136. }
// Enter a cell's construct: create (or reuse, if already compiled) its
// FuncGraph, bind the cell arguments as graph parameters and push the graph
// onto the context stack. The top cell additionally creates the df builder and
// pipeline resource.
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
  auto cell_id = GetCellId(cell, args);
  // judge graph_context_.empty() to create sperate graphs except for the top
  if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) {
    if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
      resource_ = cell_resource_map_[cell_id];
    }
    MS_LOG(DEBUG) << "Newgraph already compiled";
    return;
  }
  auto g = std::make_shared<FuncGraph>();
  if (graph_context_.empty()) {
    // Top cell: reject Parameter objects passed as plain inputs.
    for (auto arg : args) {
      if (py::isinstance<tensor::Tensor>(arg)) {
        auto tensor = arg.cast<tensor::TensorPtr>();
        if (tensor && tensor->is_parameter()) {
          MS_EXCEPTION(TypeError) << "The inputs could not be Parameter.";
        }
      }
    }
    // a df builder is built for every top function graph
    df_builder_ = std::make_shared<FuncGraph>();
    df_builder_map_[cell_id] = df_builder_;
    top_g_ = curr_g_ = g;
    resource_ = std::make_shared<pipeline::Resource>();
    resource_->results()[pipeline::kPynativeGraphId] = graph_id_++;
    cell_resource_map_[cell_id] = resource_;
    MS_LOG(DEBUG) << "First new graph" << top_g_.get();
    first_grad_step_ = true;
    top_graph_cells_.insert(cell_id);
  } else {
    // Nested cell: the top cell must already have set up the df builder.
    if (df_builder_ == nullptr) {
      MS_LOG(EXCEPTION) << "In NewGraphInner, got df builder is nullptr";
    }
    curr_g_ = g;
  }
  Pushp();
  if (graph_info_map_.count(g) == 0) {
    graph_info_map_[g] = GraphInfo();
  }
  // Register every python argument (including tuple elements, recursively) as
  // graph parameters.
  for (size_t i = 0; i < args.size(); i++) {
    auto param = args[i];
    auto new_param = g->add_parameter();
    std::string param_obj = GetId(param);
    if (py::isinstance<py::tuple>(param)) {
      auto tuple = param.cast<py::tuple>();
      auto tuple_size = static_cast<int>(tuple.size());
      for (int j = 0; j < tuple_size; j++) {
        set_param_map(curr_g_, GetId(tuple[j]), new_param, j);
        SetTupleParam(tuple[j], new_param, std::vector<int>{j});
      }
    }
    set_param_map(curr_g_, param_obj, new_param);
  }
}
  1192. AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
  1193. ValuePtr converted_ret = nullptr;
  1194. parse::ConvertData(obj, &converted_ret);
  1195. auto node = NewValueNode(converted_ret);
  1196. set_obj_node_map(curr_g_, obj_id, node);
  1197. return node;
  1198. }
// Resolve a python input object to an AnfNode in the current graph. Parameter
// inputs (op_mask set) become free parameters on the df builder; objects
// already known as graph params/outputs are reused; tuples become MakeTuple
// cnodes; everything else becomes a ValueNode.
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
  AnfNodePtr node = nullptr;
  std::string obj_id = GetId(obj);
  if (op_mask) {
    MS_LOG(DEBUG) << "Topgraph free parameter";
    // get the parameter name from parameter object
    auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    auto param_name = py::cast<std::string>(name_attr);
    // First sighting of this parameter: add it to the df builder with its
    // tensor as the default value.
    if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
      auto free_param = df_builder_->add_parameter();
      free_param->set_name(param_name);
      free_param->debug_info()->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(obj);
      free_param->set_default_param(value);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      set_param_map(df_builder_, obj_id, free_param);
      return free_param;
    }
    return graph_info_map_[df_builder_].param_map[obj_id].first;
  }
  // if input is graph output
  if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
    // op(x, y)
    node = GetParamNode(obj);
  } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj);
  } else if (py::isinstance<py::tuple>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    auto tuple = obj.cast<py::tuple>();
    // cell((1,2)): support not mix (scalar, tensor)
    if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
      return MakeValueNode(obj, obj_id);
    }
    // Tensor tuple: build a MakeTuple cnode from recursively resolved elements.
    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple_size = static_cast<int>(tuple.size());
    for (int i = 0; i < tuple_size; i++) {
      args.push_back(GetInput(tuple[i], false));
    }
    auto cnode = curr_g_->NewCNode(args);
    set_obj_node_map(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    node = MakeValueNode(obj, obj_id);
  }
  MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id;
  return node;
}
  1254. // for output[0][1] need getitem multi
  1255. void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx) {
  1256. if (py::isinstance<py::tuple>(obj)) {
  1257. auto tuple = obj.cast<py::tuple>();
  1258. for (int i = 0; i < static_cast<int>(tuple.size()); i++) {
  1259. std::vector<int> tmp = idx;
  1260. tmp.push_back(i);
  1261. set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, tmp);
  1262. SetTupleOutput(tuple[i], cnode, tmp);
  1263. }
  1264. }
  1265. }
  1266. // for param ((a, (b, c)), d) need multi getitem
  1267. void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr &para_node, std::vector<int> idx) {
  1268. if (py::isinstance<py::tuple>(obj)) {
  1269. auto tuple = obj.cast<py::tuple>();
  1270. for (int i = 0; i < static_cast<int>(tuple.size()); i++) {
  1271. std::vector<int> tmp = idx;
  1272. tmp.push_back(i);
  1273. set_param_map(curr_g_, GetId(tuple[i]), para_node, tmp);
  1274. SetTupleParam(tuple[i], para_node, tmp);
  1275. }
  1276. }
  1277. }
// Push the current graph so nested cell calls can restore it via Popp().
void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); }
  1279. void PynativeExecutor::Popp() {
  1280. if (graph_context_.empty()) {
  1281. MS_LOG(EXCEPTION) << "Stack graph_context_ is empty";
  1282. }
  1283. graph_context_.pop();
  1284. if (!graph_context_.empty()) {
  1285. curr_g_ = graph_context_.top();
  1286. }
  1287. }
// Leave a cell's construct: make sure the output object is represented in the
// graph (building a MakeTuple cnode or ValueNode when it is not already
// mapped), then finish the graph through EndGraphByOutId.
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
  auto cell_id = GetCellId(cell, args);
  if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) {
    MS_LOG(DEBUG) << "Endgraph already compiled";
    return;
  }
  cell_graph_map_[cell_id] = curr_g_;
  auto out_id = GetId(out);
  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
    // cell construct return x, y
    if (py::isinstance<py::tuple>(out)) {
      std::vector<AnfNodePtr> args;
      args.push_back(NewValueNode(prim::kPrimMakeTuple));
      auto tuple = out.cast<py::tuple>();
      MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size();
      auto tuple_size = static_cast<int>(tuple.size());
      // The cnode is created first, then its inputs are filled in via
      // set_inputs once every element node has been resolved.
      auto cnode = curr_g_->NewCNode(args);
      for (int i = 0; i < tuple_size; i++) {
        args.push_back(GetInput(tuple[i], false));
      }
      cnode->set_inputs(args);
      for (int i = 0; i < tuple_size; i++) {
        set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
        SetTupleOutput(tuple[i], cnode, std::vector<int>{i});
      }
      set_obj_node_map(curr_g_, out_id, cnode);
    } else {
      MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
      MakeValueNode(out, out_id);
    }
  }
  EndGraphByOutId(out_id, cell, out, args);
}
// Set the node identified by `out_id` as the output of curr_g_, build its
// gradient graph, and either splice the call into the enclosing graph (for
// a nested cell) or resolve the grad graph as the top-level func graph.
void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out,
                                       const py::args &args) {
  AnfNodePtr output_node;
  // The output may be a graph parameter or an operator result node.
  if (graph_info_map_[curr_g_].param_map.count(out_id)) {
    output_node = GetParamNode(out);
  } else {
    output_node = GetObjNode(out);
  }
  curr_g_->set_output(output_node);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(curr_g_));
  MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
  resource_->manager()->AddFuncGraph(curr_g_);
  // custom bprop debug
  bool need_replace_param = false;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    need_replace_param = true;
    // A net with its own user-defined bprop must not hold trainable parameters.
    size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
    if (par_number > 0) {
      MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
                        << " parameters that is not supported in the net.";
    }
    MS_LOG(DEBUG) << "Use cell custom bprop function.";
    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
    if (bprop_graph != nullptr) {
      // Cross-link forward graph and custom bprop graph via their transforms.
      (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
    }
  }
  // Only the outermost graph (context depth 1) is differentiated as top-level.
  auto newfg = ad::Grad(curr_g_, resource_, graph_context_.size() == 1);
  if (need_replace_param) {
    // For custom bprop, freeze the call arguments into the grad graph as
    // constant value nodes in place of its parameters.
    auto params = newfg->parameters();
    auto manager = Manage({newfg}, false);
    if (args.size() > params.size()) {
      MS_EXCEPTION(TypeError) << "The number of arguments " << args.size()
                              << " is more than the number of parameters required, which is " << params.size();
    }
    for (size_t i = 0; i < args.size(); i++) {
      ValuePtr value = PyAttrValue(args[i]);
      auto v_node = NewValueNode(value);
      manager->Replace(params[i], v_node);
    }
  }
  graph_info_map_.erase(curr_g_);
  if (graph_context_.size() > 1) {
    Popp();
    // connect the previous graph to the inside graph
    auto graph_prev = graph_context_.top();
    for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], false);
      inputs.push_back(input);
    }
    auto out_cnode = graph_prev->NewCNode(inputs);
    set_pyobj(graph_prev, GetCellId(cell, args));
    if (py::isinstance<py::tuple>(out)) {
      // Register every element (and nested tuple) of the sub-cell's result
      // in the enclosing graph.
      auto out_list = py::cast<py::tuple>(out);
      auto out_size = static_cast<int>(out_list.size());
      for (int i = 0; i < out_size; i++) {
        set_obj_node_map(graph_prev, GetId(out_list[i]), out_cnode, i);
        SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i});
      }
    }
    set_obj_node_map(graph_prev, GetId(out), out_cnode);
  } else {
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
      DumpIR("before_resolve.ir", newfg);
    }
    parse::ResolveFuncGraph(newfg, resource_);
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
      DumpIR("after_resolve.ir", newfg);
    }
    resource_->set_func_graph(newfg);
    Popp();
  }
}
  1396. std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) {
  1397. std::vector<AnfNodePtr> w_args;
  1398. if (py::hasattr(weights, "__parameter_tuple__")) {
  1399. auto tuple = weights.cast<py::tuple>();
  1400. MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size();
  1401. w_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  1402. for (size_t it = 0; it < tuple.size(); ++it) {
  1403. auto param = tuple[it];
  1404. auto param_id = GetId(param);
  1405. AnfNodePtr para_node = nullptr;
  1406. if (graph_info_map_[df_builder_].param_map.count(param_id)) {
  1407. para_node = graph_info_map_[df_builder_].param_map[param_id].first;
  1408. } else {
  1409. auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name");
  1410. if (py::isinstance<py::none>(name_attr)) {
  1411. MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
  1412. }
  1413. auto param_name = py::cast<std::string>(name_attr);
  1414. auto free_param = df_builder_->add_parameter();
  1415. free_param->set_name(param_name);
  1416. auto value = py::cast<tensor::TensorPtr>(param);
  1417. free_param->set_default_param(value);
  1418. free_param->debug_info()->set_name(param_name);
  1419. para_node = free_param;
  1420. }
  1421. w_args.push_back(para_node);
  1422. }
  1423. } else {
  1424. MS_LOG(DEBUG) << "training not paramter_tuple";
  1425. }
  1426. return w_args;
  1427. }
// Build the abstract specification list used for compilation: one broadened
// abstract per call argument, followed by one per df_builder_ parameter that
// carries a default value (i.e. weights).  The matching parameter nodes get
// the same abstracts attached.
abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) {
  abstract::AbstractBasePtrList args_spec;
  std::size_t size = args.size();
  for (std::size_t i = 0; i < size; i++) {
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(args[i], &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "Args convert error";
    }
    // Broaden so the spec matches any value of the same type/shape, not just
    // this concrete value.
    bool broaden = true;
    auto abs = abstract::FromValue(converted, broaden);
    args_spec.push_back(abs);
    // NOTE(review): assumes df_builder_ holds at least args.size()
    // parameters (GradNetInner prepends them) -- no bounds check here.
    auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]);
    param_node->set_abstract(abs);
  }
  for (const auto &param : df_builder_->parameters()) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    if (param_node->has_default()) {
      ValuePtr value = param_node->default_param();
      auto ptr = value->ToAbstract();
      if (ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Args convert error";
      }
      args_spec.push_back(ptr);
      param_node->set_abstract(ptr);
    }
  }
  return args_spec;
}
// Compile the gradient graph for `cell`: locate the forward graph recorded
// by EndGraph, wrap it with the GradOperation inside df_builder_, run
// inference/optimization and emit backend tasks, caching the result under
// the full cell id.
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                    const py::args &args) {
  MS_LOG(INFO) << "GradNet start" << args.size();
  std::size_t size = args.size();
  std::string cell_id = GetCellId(cell, args);
  if (graph_map_.count(cell_id) != 0) {
    MS_LOG(DEBUG) << "GradNet already compiled";
    return;
  }
  // The forward cell id is computed without the trailing sens argument.
  size_t forward_args_count = args.size();
  if (grad->sens_param()) {
    forward_args_count = forward_args_count - 1;
  }
  py::tuple forward_args(forward_args_count);
  for (size_t i = 0; i < forward_args_count; i++) {
    forward_args[i] = args[i];
  }
  std::string forward_cell_id = GetCellId(cell, forward_args);
  MS_LOG(DEBUG) << "Forward cell_id:" << forward_cell_id;
  if (df_builder_map_.find(forward_cell_id) == df_builder_map_.end()) {
    MS_LOG(EXCEPTION) << "Cannot find df builder";
  }
  df_builder_ = df_builder_map_[forward_cell_id];
  if (df_builder_ == nullptr) {
    MS_LOG(EXCEPTION) << "Got unexpected null df builder";
  }
  if (cell_resource_map_.find(forward_cell_id) == cell_resource_map_.end()) {
    MS_LOG(EXCEPTION) << "Cannot find resource for " << forward_cell_id;
  }
  MS_LOG(DEBUG) << "GradNet first compiled";
  resource_ = cell_resource_map_[forward_cell_id];
  // Prepend one fresh parameter per call argument; the parameters already in
  // df_builder_ (weights) stay behind them.
  std::vector<AnfNodePtr> new_params;
  for (size_t i = 0; i < size; i++) {
    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
    new_params.push_back(p);
  }
  MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
  df_builder_->set_parameters(new_params);
  resource_->manager()->SetParameters(df_builder_, new_params);
  std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
  MS_EXCEPTION_IF_NULL(resource_->func_graph());
  if (cell_graph_map_.find(forward_cell_id) == cell_graph_map_.end()) {
    MS_LOG(EXCEPTION) << "Could not find top graph by cellid: " << forward_cell_id;
  }
  top_g_ = cell_graph_map_[forward_cell_id];
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR("before_grad.ir", resource_->func_graph());
  }
  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR("after_grad.ir", g);
  }
  resource_->set_func_graph(g);
  resource_->manager()->KeepRoots({g});
  // get the parameters items and add the value to args_spec
  abstract::AbstractBasePtrList args_spec = GetArgsSpec(args);
  MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
  resource_->set_args_spec(args_spec);
  MS_LOG(DEBUG) << "Start opt";
  // Create backend and session
  resource_->results()[pipeline::kBackend] = compile::CreateBackend();
  graph_map_[cell_id] = g;
  PynativeOptimizeAction(resource_);
  TaskEmitAction(resource_);
  ExecuteAction(resource_);
  // Release compile-time state once the executable is built.
  resource_->Clean();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
  1527. template <typename T>
  1528. void MapClear(T *map, const std::string &flag) {
  1529. for (auto it = map->begin(); it != map->end();) {
  1530. if (it->first.find(flag) != std::string::npos) {
  1531. it->second = nullptr;
  1532. it = map->erase(it);
  1533. } else {
  1534. it++;
  1535. }
  1536. }
  1537. }
// Clear cached compilation state.  With a non-empty `flag`, erase only the
// map entries whose cell id contains `flag` (and, when `flag` names a top
// graph cell, also drop forward op caches via Clean).  With an empty flag,
// reset the per-graph recording state instead.
void PynativeExecutor::Clear(const std::string &flag) {
  if (!flag.empty()) {
    MS_LOG(DEBUG) << "Clear res";
    MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&graph_map_, flag);
    MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&cell_graph_map_, flag);
    MapClear<std::unordered_map<std::string, ResourcePtr>>(&cell_resource_map_, flag);
    MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&df_builder_map_, flag);
    // Maybe exit in the pynative runing op, so need reset pynative flag.
    auto ms_context = MsContext::GetInstance();
    if (ms_context != nullptr) {
      ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
    }
    ConfigManager::GetInstance().ResetIterNum();
    if (top_graph_cells_.find(flag) != top_graph_cells_.end()) {
      op_forward_map_.clear();
      Clean();
    }
    return;
  }
  MS_LOG(DEBUG) << "Clear";
  grad_flag_ = false;
  top_g_ = nullptr;
  df_builder_ = nullptr;
  curr_g_ = nullptr;
  first_grad_step_ = false;
  graph_info_map_.clear();
  op_id_map_.clear();
  obj_to_forward_id_.clear();
  // Discard any graphs still on the context stack.
  std::stack<FuncGraphPtr>().swap(graph_context_);
  ConfigManager::GetInstance().ResetIterNum();
}
// Release all pynative resources: per-graph recording state (Clear with an
// empty flag), autodiff caches, and the optimizer pool.
void PynativeExecutor::Clean() {
  MS_LOG(DEBUG) << "Clean all res";
  Clear();
  grad_flag_ = false;
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
  1576. template <typename T>
  1577. void MapErase(T *map) {
  1578. for (auto it = map->begin(); it != map->end();) {
  1579. it = map->erase(it++);
  1580. }
  1581. }
// Final teardown (e.g. at shutdown): drop every cached map entry, clean all
// shared resources and release the pipeline resource.
void PynativeExecutor::ClearRes() {
  MapErase<std::unordered_map<std::string, FuncGraphPtr>>(&graph_map_);
  MapErase<std::unordered_map<std::string, FuncGraphPtr>>(&cell_graph_map_);
  MapErase<std::unordered_map<std::string, ResourcePtr>>(&cell_resource_map_);
  MapErase<std::unordered_map<std::string, abstract::AbstractBasePtr>>(&node_abs_map_);
  Clean();
  resource_.reset();
}
  1590. size_t GetTupleSize(const py::tuple &args) {
  1591. size_t count = 0;
  1592. for (size_t i = 0; i < args.size(); i++) {
  1593. if (py::isinstance<py::tuple>(args[i])) {
  1594. count += GetTupleSize(args[i]);
  1595. } else {
  1596. count += 1;
  1597. }
  1598. }
  1599. return count;
  1600. }
  1601. void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) {
  1602. for (size_t i = 0; i < arg.size(); i++) {
  1603. if (py::isinstance<py::tuple>(arg[i])) {
  1604. ConvertTupleArg(res, index, arg[i]);
  1605. } else {
  1606. (*res)[(*index)++] = arg[i];
  1607. }
  1608. }
  1609. }
  1610. py::tuple ConvertArgs(const py::tuple &args) {
  1611. size_t tuple_size = GetTupleSize(args);
  1612. py::tuple res(tuple_size);
  1613. size_t index = 0;
  1614. for (size_t i = 0; i < args.size(); i++) {
  1615. if (py::isinstance<py::tuple>(args[i])) {
  1616. ConvertTupleArg(&res, &index, args[i]);
  1617. } else {
  1618. res[index++] = args[i];
  1619. }
  1620. }
  1621. return res;
  1622. }
// Execute the compiled graph: flatten nested tuple args, marshal them into a
// VectorRef, and invoke the VM eval function cached under kOutput in the
// resource results.
// NOTE(review): `phase` is part of the python-facing signature but is not
// used in this implementation.
py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
  VectorRef arg_list;
  py::tuple converted_args = ConvertArgs(args);
  pipeline::ProcessVmArgInner(converted_args, resource_, &arg_list);
  if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
      !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Can't find run graph func for ";
  }
  compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  if (run == nullptr) {
    MS_LOG(EXCEPTION) << "Can't find run graph func for ";
  }
  std::string backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run" << backend;
  BaseRef value = (*run)(arg_list);
  MS_LOG(DEBUG) << "Run end" << value.ToString();
  return BaseRefToPyData(value);
}
// Wrap forward graph `g` with the GradOperation inside df_builder_: the grad
// function is built from top_g_'s parameters plus the given weight nodes,
// and the first `arg_size` df_builder_ parameters are fed to it as inputs.
// Returns df_builder_, whose output is the grad call.
FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op,
                                         const std::vector<AnfNodePtr> &weights, size_t arg_size) {
  auto nparam = top_g_->parameters().size();
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  df_builder_->debug_info()->set_name(ss.str());
  auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights);
  std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
  for (size_t i = 0; i < arg_size; ++i) {
    inputs.push_back(df_builder_->parameters()[i]);
  }
  auto out = df_builder_->NewCNode(inputs);
  df_builder_->set_output(out);
  resource_->manager()->AddFuncGraph(df);
  resource_->manager()->AddFuncGraph(df_builder_);
  return df_builder_;
}
// Python entry point: begin recording a new graph for `cell`; exceptions are
// normalized by PynativeExecutorTry.
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args);
}
// Python entry point: finish recording the graph for `cell` with output
// `out`; exceptions are normalized by PynativeExecutorTry.
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
}
// Python entry point: build the gradient graph for `cell` with the given
// weights; exceptions are normalized by PynativeExecutorTry.
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
}
// Expose PynativeExecutor to python as the singleton class `PynativeExecutor_`.
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
                           .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""),
                                "Executor run function.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.");
                       }));
  1681. } // namespace pynative
  1682. } // namespace mindspore