
pynative_execute.cc 101 kB

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/pynative/pynative_execute.h"
#include <typeinfo>
#include <map>
#include <set>
#include <memory>
#include <sstream>
#include <unordered_set>
#include <algorithm>
#include "debug/trace.h"
#include "pybind_api/ir/tensor_py.h"
#include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/cell.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "utils/context/context_extends.h"
#include "utils/config_manager.h"
#include "utils/convert_utils_py.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
#include "backend/optimizer/common/helper.h"
#include "pipeline/jit/action.h"
#include "pipeline/pynative/base.h"
#include "pybind_api/api_register.h"
#include "vm/transform.h"
#include "frontend/optimizer/ad/grad.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/pipeline.h"
#include "pipeline/jit/pass.h"
#ifdef ENABLE_GE
#include "pipeline/pynative/pynative_execute_ge.h"
#endif
#include "debug/anf_ir_dump.h"
using mindspore::tensor::TensorPy;
const size_t PTR_LEN = 15;
// Primitives for which the value of a constant input cannot be inferred in PyNative mode.
static const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
                                                   "mixed_precision_cast"};
static const char kOpsFunctionModelName[] = "mindspore.ops.functional";
static const char kMSDtypeModelName[] = "mindspore.common.dtype";
namespace mindspore::pynative {
static std::shared_ptr<session::SessionBasic> session = nullptr;
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
int64_t PynativeExecutor::graph_id_ = 0;
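// Note: PynativeExecutorTry funnels every executor entry point through one
// error-handling path: executor state is cleared before any exception leaves,
// pybind11 exception types are rethrown as-is so Python sees the right error
// class, and anything else is normalized to std::runtime_error.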
template <typename... Args>
void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
  MS_EXCEPTION_IF_NULL(executor);
  try {
    (executor->*method)(args...);
  } catch (const py::error_already_set &ex) {
    // Print the function call stack info before release.
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // Call py::print to output the call stack to STDOUT, so that even when logs go to a
    // file the user can see this info on screen without opening the log file.
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    PynativeExecutor::GetInstance()->ClearRes();
    // Re-throw this exception to the Python interpreter to handle it.
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    // Re-throw this exception to the Python interpreter to handle it.
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    PynativeExecutor::GetInstance()->ClearRes();
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compiling graph. Exception name: " << exName;
  }
}
inline ValuePtr PyAttrValue(const py::object &obj) {
  ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  if (!converted_ret) {
    MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
  }
  return converted_ret;
}
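// Note: GetId derives a stable string key for a Python object. Tensors use their
// own id, types and scalars use their string form, and tuples/lists concatenate
// the ids of their elements; these keys index node_abs_map_ and the graph info maps.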
static std::string GetId(const py::object &obj) {
  if (py::isinstance<tensor::Tensor>(obj)) {
    auto tensor_ptr = py::cast<tensor::TensorPtr>(obj);
    return tensor_ptr->id();
  } else if (py::isinstance<mindspore::Type>(obj)) {
    auto type_ptr = py::cast<mindspore::TypePtr>(obj);
    return "type" + type_ptr->ToString();
  } else if (py::isinstance<py::str>(obj) || py::isinstance<py::int_>(obj) || py::isinstance<py::float_>(obj)) {
    return std::string(py::str(obj));
  } else if (py::isinstance<py::none>(obj)) {
    return "none";
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    auto p_list = py::cast<py::tuple>(obj);
    std::string prefix = py::isinstance<py::tuple>(obj) ? "tuple:" : "list:";
    if (p_list.empty()) {
      prefix = "empty";
    } else {
      std::string key;
      for (size_t i = 0; i < p_list.size(); ++i) {
        key += std::string(py::str(GetId(p_list[i]))) + ":";
      }
      prefix += key;
    }
    return prefix;
  }
  py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  return py::cast<std::string>(ret);
}
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
  std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
  for (size_t i = 0; i < dtypes.size(); ++i) {
    auto it = type_indexes.find(dtypes[i]);
    if (it == type_indexes.end()) {
      (void)type_indexes.emplace(std::make_pair(dtypes[i], std::vector<size_t>{i}));
    } else {
      it->second.emplace_back(i);
    }
  }
  return type_indexes;
}
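// Note: GetDstType implements the implicit type-promotion rules for arguments that
// share a dtype signature: the highest-priority tensor dtype wins; bool is promoted
// to int64/float32 when mixed with scalar ints/floats; a non-float result mixed
// with a scalar float becomes float32; and uint8 mixed with an int8 tensor widens
// to int16.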
std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
                                                const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
  std::map<SignatureEnumDType, TypeId> dst_type;
  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
    auto type = it->first;
    auto indexes = it->second;
    if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < 2) {
      continue;
    }
    size_t priority = 0;
    TypeId max_type = TypeId::kTypeUnknown;
    bool has_scalar_float32 = false;
    bool has_scalar_int64 = false;
    bool has_tensor_int8 = false;
    for (size_t index : indexes) {
      if (!has_scalar_float32 && py::isinstance<py::float_>(py_args[index])) {
        has_scalar_float32 = true;
      }
      if (!has_scalar_int64 && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
        has_scalar_int64 = true;
      }
      auto obj = py_args[index];
      if (py::isinstance<tensor::Tensor>(obj)) {
        auto arg = py::cast<tensor::TensorPtr>(obj);
        TypeId arg_type_id = arg->data_type();
        auto type_priority = prim::type_map.find(arg_type_id);
        if (type_priority == prim::type_map.end()) {
          continue;
        }
        if (arg_type_id == kNumberTypeInt8) {
          has_tensor_int8 = true;
        }
        if (type_priority->second > priority) {
          max_type = type_priority->first;
          priority = type_priority->second;
        }
      }
    }
    if (max_type == TypeId::kNumberTypeBool) {
      if (has_scalar_int64) {
        max_type = TypeId::kNumberTypeInt64;
      }
      if (has_scalar_float32) {
        max_type = TypeId::kNumberTypeFloat32;
      }
    }
    if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
        max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
      max_type = TypeId::kNumberTypeFloat32;
    }
    if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
      max_type = TypeId::kNumberTypeInt16;
    }
    (void)dst_type.emplace(std::make_pair(type, max_type));
  }
  return dst_type;
}
std::string TypeIdToMsTypeStr(const TypeId &type_id) {
  auto type_name = type_name_map.find(type_id);
  if (type_name == type_name_map.end()) {
    MS_LOG(EXCEPTION) << "For implicit type conversion, converting to the type " << TypeIdToType(type_id)
                      << " is not supported";
  }
  return type_name->second;
}
bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
  MS_EXCEPTION_IF_NULL(dtypes);
  auto signature = prim->signatures();
  bool has_sig_dtype = false;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
                       [&has_sig_dtype](const Signature &sig) {
                         auto dtype = sig.dtype;
                         if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
                           has_sig_dtype = true;
                         }
                         return dtype;
                       });
  return has_sig_dtype;
}
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
                   const abstract::AbstractBasePtrList &args_spec_list) {
  MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
  prim->BeginRecordAddAttr();
  AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  prim->EndRecordAddAttr();
  op_exec_info->abstract = infer_res;
  MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}
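// Note: the graph info string built below serves as the cache key for single-op
// graphs, so it encodes everything that can change the compiled result: input
// shapes, dtypes, device address type/format, the primitive id, evaluated attrs,
// and the output shape/type.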
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
                                 const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  std::string graph_info;
  // Get input tensor info.
  for (const auto &tensor : input_tensors) {
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_shape = tensor->shape();
    (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
                        [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
    (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
    if (tensor->device_address() != nullptr) {
      (void)graph_info.append(
        std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id()) + "_");
      (void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format() +
                              "_");
    }
  }
  // Get prim and abstract info.
  (void)graph_info.append(op_exec_info->prim_id + "_");
  // Get attr info.
  const auto &op_prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &attr_map = op_prim->evaluate_added_attrs();
  (void)std::for_each(attr_map.begin(), attr_map.end(),
                      [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
  // Add the output information (shape, type id) of the operator to graph_info, to avoid cache misses
  // caused by operators like DropoutGenMask whose output depends on the values of the inputs when
  // the input shapes are the same but the values differ.
  auto abstr = op_exec_info->abstract;
  MS_EXCEPTION_IF_NULL(abstr);
  auto build_shape = abstr->BuildShape();
  MS_EXCEPTION_IF_NULL(build_shape);
  (void)graph_info.append(build_shape->ToString() + "_");
  auto build_type = abstr->BuildType();
  MS_EXCEPTION_IF_NULL(build_type);
  (void)graph_info.append(std::to_string(build_type->type_id()) + "_");
  return graph_info;
}
bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                  const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(op_prim);
  auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  if (input_names_value == nullptr) {
    return false;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  if (input_index >= input_names_vec.size()) {
    MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
  }
  if (input_attrs.find(input_index) != input_attrs.end()) {
    ValuePtr value = parse::data_converter::PyDataToValue(input_object);
    MS_EXCEPTION_IF_NULL(value);
    auto input_name = input_names_vec[input_index];
    op_prim->AddAttr(input_name, value);
    return true;
  }
  return false;
}
void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
                              std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    input_tensors->emplace_back(tensor);
  }
  op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(tuple_inputs.size())}));
}
void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }
  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}
void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                                  std::vector<tensor::TensorPtr> *input_tensors, int64_t *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }
  auto tuple_inputs = py::cast<py::tuple>(input_object);
  if (tuple_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }
  if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
    PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
  } else {
    ConvertValueTupleToTensor(input_object, input_tensors);
    *tensor_mask = kValueNodeTensorMask;
  }
}
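// Note: ConvertPyObjectToTensor dispatches on the Python argument type: tensors
// pass through, Python scalars become float32/int64 value-node tensors, numpy
// arrays are wrapped, and (possibly nested) tuples/lists are handled by
// ConvertMultiPyObjectToTensor above.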
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                             std::vector<tensor::TensorPtr> *input_tensors, int64_t *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (py::isinstance<tensor::Tensor>(input_object)) {
    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  } else if (py::isinstance<py::float_>(input_object)) {
    double input_value = py::cast<py::float_>(input_object);
    tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::int_>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::array>(input_object)) {
    tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
  } else if (py::isinstance<py::list>(input_object)) {
    auto list_inputs = py::cast<py::list>(input_object);
    py::tuple tuple_inputs(list_inputs.size());
    for (size_t i = 0; i < tuple_inputs.size(); ++i) {
      tuple_inputs[i] = list_inputs[i];
    }
    ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    return;
  } else {
    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}
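// Note: ConstructInputTensor walks the op inputs once, converting registered
// const inputs to attrs when allowed (dynamic-shape ops and some device-specific
// ops opt out), converting everything else to tensors, and recording a parallel
// mask per tensor: data 0, weight 1, value node 2.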
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  if (op_run_info->is_dynamic_shape &&
      dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
    MS_LOG(INFO) << "Current node is dynamic shape: " << op_run_info->op_name;
    reg_exist = false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
    if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
      reg_exist = false;
    }
  }
  if (op_run_info->op_name == prim::kPrimGatherD->name()) {
    // The Gather op needs its const input converted to an attr only on the GPU device.
    if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
      reg_exist = false;
    }
  }
  op_prim->BeginRecordAddAttr();
  size_t input_num = op_run_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    // Convert const input to attr.
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // Convert const and tuple input to tensor.
    int64_t tensor_mask = static_cast<int64_t>(op_run_info->inputs_mask[index]);
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // Mark tensors: data 0, weight 1, value node 2.
    std::vector<int64_t> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
  op_prim->EndRecordAddAttr();
}
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      auto output = TransformBaseRefListToTuple(ref_list[i]);
      if (utils::isa<tensor::TensorPtr>(output)) {
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    return std::make_shared<PyObjectRef>(output_tensors);
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
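// Note: the three helpers below flatten arbitrarily nested Python tuples into a
// single flat tuple: GetTupleSize counts leaf elements recursively, ConvertTupleArg
// copies leaves into the result in order, and ConvertArgs drives the flattening.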
size_t GetTupleSize(const py::tuple &args) {
  size_t count = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      count += GetTupleSize(args[i]);
    } else {
      count += 1;
    }
  }
  return count;
}
void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) {
  for (size_t i = 0; i < arg.size(); i++) {
    if (py::isinstance<py::tuple>(arg[i])) {
      ConvertTupleArg(res, index, arg[i]);
    } else {
      (*res)[(*index)++] = arg[i];
    }
  }
}
py::tuple ConvertArgs(const py::tuple &args) {
  size_t tuple_size = GetTupleSize(args);
  py::tuple res(tuple_size);
  size_t index = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      ConvertTupleArg(&res, &index, args[i]);
    } else {
      res[index++] = args[i];
    }
  }
  return res;
}
void ClearPyNativeSession() { session = nullptr; }
PynativeExecutor::~PynativeExecutor() {
  MS_LOG(DEBUG) << "PynativeExecutor destructor";
  ClearRes();
}
py::tuple RunOp(const py::args &args) {
  auto executor = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(DEBUG) << "RunOp name: " << op_exec_info->op_name << " start, args: " << args.size();
  try {
    return executor->RunOpInner(op_exec_info);
  } catch (const py::error_already_set &ex) {
    executor->ClearRes();
    // Re-throw this exception to the Python interpreter to handle it.
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    executor->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    executor->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    executor->ClearRes();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    executor->ClearRes();
    // Re-throw this exception to the Python interpreter to handle it.
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    executor->ClearRes();
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compiling graph. Exception name: " << exName;
  }
}
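// Note: RunOpInner is the single-op pipeline: build a CNode (for the grad graph),
// look up or infer the output abstract, short-circuit ops whose value is already
// known from inference, run the op on the selected backend, and finally record
// the outputs for later grad-graph construction.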
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
    return RunOpWithInitBackendPolicy(op_exec_info);
  }
  // Make a CNode for building the grad graph if the grad flag is set.
  abstract::AbstractBasePtrList args_spec_list;
  std::vector<bool> op_masks;
  auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list);
  op_exec_info->inputs_mask = op_masks;
  // Get output abstract info.
  bool is_find = false;
  GetOpOutputAbstract(op_exec_info, args_spec_list, &is_find);
  MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString();
  // Infer the output value for a const prim.
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
  if (!output["value"].is_none()) {
    py::tuple value_ret(1);
    value_ret[0] = output["value"];
    return value_ret;
  }
  if (prim->is_const_prim()) {
    py::tuple value_ret(1);
    value_ret[0] = "";
    return value_ret;
  }
  // Add output abstract info into the cache.
  if (!is_find && !op_exec_info->is_dynamic_shape) {
    // const_value needs to be inferred every step.
    auto &out = prim_abs_list_[prim->id()];
    out[args_spec_list].abs = op_exec_info->abstract;
    out[args_spec_list].attrs = prim->evaluate_added_attrs();
    MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
  }
  // Run the op with the selected backend.
  auto result = RunOpWithInitBackendPolicy(op_exec_info);
  py::object out_real = result;
  if (result.size() == 1) {
    MS_LOG(DEBUG) << "Output size is 1";
    out_real = result[0];
  }
  // Update the output abstract for the cnode.
  if (cnode != nullptr) {
    cnode->set_abstract(op_exec_info->abstract);
  }
  std::string obj_id = GetId(out_real);
  node_abs_map_[obj_id] = op_exec_info->abstract;
  // Save info for building the grad graph.
  SaveOutputNodeMap(obj_id, out_real, cnode);
  SaveAllResult(op_exec_info, cnode, out_real);
  // Update the abstract and device address of value nodes with tensors in the grad graph.
  UpdateAbstractAndDeviceAddress(op_exec_info, out_real);
  return result;
}
OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(ERROR) << "Three args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
  auto op_name = py::cast<std::string>(args[PY_NAME]);
  op_exec_info->op_name = op_name;
  if (grad_flag()) {
    int64_t graph_id = graph_id_;
    auto resource = GetResource();
    if (resource != nullptr) {
      MS_LOG(DEBUG) << "Get resource ptr " << resource.get();
      auto it = resource->results().find(pipeline::kPynativeGraphId);
      if (it != resource->results().end()) {
        graph_id = it->second.cast<int64_t>();
      }
    }
    op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]);
    op_index_map_[op_name]++;
  }
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  MS_EXCEPTION_IF_NULL(prim);
  if (!prim->HasPyObj()) {
    MS_LOG(EXCEPTION) << "Pyobj is empty";
  }
  op_exec_info->prim_id = GetId(prim->GetPyObj());
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  op_exec_info->op_inputs = args[PY_INPUTS];
  return op_exec_info;
}
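// Note: MakeCNode assembles the forward CNode for an op call: it applies the
// parameter auto-mix-precision cast (except for Cast itself), collects or infers
// an abstract for every input, broadens non-const abstracts (presumably so the
// abstract cache matches on shape/type rather than value), and creates the CNode
// only when a grad graph is being constructed.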
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
                                       abstract::AbstractBasePtrList *args_spec_list) {
  MS_EXCEPTION_IF_NULL(op_masks);
  MS_EXCEPTION_IF_NULL(args_spec_list);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto prim = op_exec_info->py_primitive;
  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(prim));
  const auto &signature = prim->signatures();
  auto sig_size = signature.size();
  auto size = op_exec_info->op_inputs.size();
  // Ignore the signature for the cast op.
  if (sig_size > 0 && sig_size != size) {
    MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the required "
                             << "inputs size " << sig_size;
  }
  if (op_exec_info->op_name != prim::kPrimCast->name()) {
    RunParameterAutoMixPrecisionCast(op_exec_info);
  }
  MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    abstract::AbstractBasePtr abs = nullptr;
    const auto &obj = op_exec_info->op_inputs[i];
    auto id = GetId(obj);
    auto it = node_abs_map_.find(id);
    if (it != node_abs_map_.end()) {
      abs = it->second;
    }
    bool op_mask = false;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor) {
        op_mask = meta_tensor->is_parameter();
      }
    }
    MS_LOG(DEBUG) << "Gen args i " << i << " op_mask " << op_mask;
    (*op_masks).emplace_back(op_mask);
    if (need_construct_graph()) {
      AnfNodePtr input_node = nullptr;
      if (!graph_info_map_.empty() && !top_cell_list_.empty()) {
        input_node = GetInput(obj, op_mask);
      }
      // Update the abstract.
      if (input_node != nullptr && input_node->abstract() != nullptr) {
        abs = input_node->abstract();
      }
      if (input_node != nullptr) {
        inputs.emplace_back(input_node);
      }
    }
    auto const_input_index = prim->get_const_input_indexes();
    bool have_const_input = !const_input_index.empty();
    bool is_const_prim = prim->is_const_prim();
    MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                  << prim->is_const_prim();
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    if (abs == nullptr || is_const_prim || is_const_input) {
      MS_LOG(DEBUG) << "MakeCNode: node " << id << " is not in the abstract map";
      ValuePtr input_value = PyAttrValue(obj);
      abs = input_value->ToAbstract();
      if (!is_const_prim && !is_const_input) {
        auto config = abstract::AbstractBase::kBroadenTensorOnly;
        abs = abs->Broaden(config);
        MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
      }
      node_abs_map_[id] = abs;
    }
    (*args_spec_list).emplace_back(abs);
  }
  CNodePtr cnode = nullptr;
  if (need_construct_graph()) {
    MS_EXCEPTION_IF_NULL(curr_g_);
    cnode = curr_g_->NewCNode(inputs);
    MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << " new cnode is " << cnode->DebugString(4);
  }
  return cnode;
}
void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
                                           const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
  MS_EXCEPTION_IF_NULL(is_find);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  *is_find = false;
  auto op_name = op_exec_info->op_name;
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
    auto abs_list = prim_abs_list_[prim->id()];
    MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
    if (abs_list.find(args_spec_list) != abs_list.end()) {
      MS_LOG(DEBUG) << "Match prim ok " << op_name;
      op_exec_info->abstract = abs_list[args_spec_list].abs;
      prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
      *is_find = true;
    }
  }
  if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
    // Use the Python infer method.
    if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
      PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
    }
  }
  // Get output dynamic shape info.
  auto py_abstract = op_exec_info->abstract;
  MS_EXCEPTION_IF_NULL(py_abstract);
  auto py_shape = py_abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(py_shape);
  auto py_shape_info = py_shape->ToString();
  if (py_shape_info.find("-1") != string::npos) {
    auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
    MS_EXCEPTION_IF_NULL(c_abstract);
    auto c_shape = c_abstract->BuildShape();
    MS_EXCEPTION_IF_NULL(c_shape);
    auto c_shape_info = c_shape->ToString();
    MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info;
    if (c_shape_info.find("-1") != string::npos) {
      op_exec_info->is_dynamic_shape = true;
    }
  }
}
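// Note: DoAutoCast synthesizes a call to the functional cast op by building the
// same (prim, name, inputs) tuple that RunOp receives from Python, then runs it
// through RunOpInner; next_op_name/next_input_index record which input of which
// op triggered the cast.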
py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
                                        size_t index) {
  py::tuple cast_args(3);
  cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  cast_args[PY_NAME] = prim::kPrimCast->name();
  std::string dst_type_str = TypeIdToMsTypeStr(type_id);
  py::object dst_type = parse::python_adapter::GetPyFn(kMSDtypeModelName, dst_type_str);
  py::tuple inputs(2);
  inputs[0] = arg;
  inputs[1] = dst_type;
  cast_args[PY_INPUTS] = inputs;
  auto op_exec = GenerateOpExecInfo(cast_args);
  op_exec->is_mixed_precision_cast = true;
  op_exec->next_op_name = op_name;
  op_exec->next_input_index = index;
  return RunOpInner(op_exec)[0];
}
py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name,
                                                     size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  auto tensor = py::cast<tensor::TensorPtr>(obj);
  auto cast_type = tensor->cast_dtype();
  py::object cast_output = obj;
  if (cast_type != nullptr) {
    auto source_element = tensor->Dtype();
    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
      MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
      *is_cast = true;
      return DoAutoCast(obj, cast_type->type_id(), op_name, index);
    }
  }
  return cast_output;
}
py::object PynativeExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple,
                                                          const std::string &op_name, size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  auto tuple_size = static_cast<int64_t>(tuple.size());
  py::tuple result(tuple_size);
  for (int64_t i = 0; i < tuple_size; i++) {
    if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
      MS_LOG(DEBUG) << "Call cast for item " << i;
      result[i] = DoParamMixPrecisionCast(is_cast, tuple[i], op_name, index);
    } else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
      result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i], op_name, index);
    } else {
      result[i] = tuple[i];
    }
  }
  return std::move(result);
}
void PynativeExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
                                       const std::vector<SignatureEnumDType> &dtypes,
                                       const OpExecInfoPtr &op_exec_info) {
  const auto &signature = prim->signatures();
  auto &out_args = op_exec_info->op_inputs;
  for (size_t i = 0; i < out_args.size(); ++i) {
    // No need to implicitly cast if there is no dtype.
    if (dtypes.empty() || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
      continue;
    }
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end() || it->second == kTypeUnknown) {
      continue;
    }
    MS_LOG(DEBUG) << "Check inputs " << i;
    auto obj = out_args[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      sig = signature[i].rw;
    }
    bool is_parameter = false;
    TypeId arg_type_id = kTypeUnknown;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
      if (arg->is_parameter()) {
        is_parameter = true;
        MS_LOG(DEBUG) << "Parameter is read " << i;
      }
      arg_type_id = arg->data_type();
    }
    // Implicit cast.
    bool is_same_type = false;
    if (arg_type_id != kTypeUnknown) {
      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
    }
    if (sig == SignatureEnumRW::kRWWrite) {
      if (!is_parameter) {
        prim::RaiseExceptionForCheckParameter(prim->name(), i, "not");
      }
      if (arg_type_id != kTypeUnknown) {
        if (!is_same_type) {
          prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
                                                 TypeIdToMsTypeStr(it->second));
        }
      }
    }
    if (is_same_type) {
      continue;
    }
    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
                              << "th input is of a type that does not support implicit conversion: "
                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
                              << py::cast<py::str>(obj) << ".";
    }
    py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i);
    out_args[i] = cast_output;
  }
}
void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
  size_t size = op_exec_info->op_inputs.size();
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &signature = prim->signatures();
  for (size_t i = 0; i < size; i++) {
    auto obj = op_exec_info->op_inputs[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      sig = signature[i].rw;
    }
    MS_LOG(DEBUG) << "Check mix precision " << op_exec_info->op_name << " input " << i << " "
                  << std::string(py::repr(obj));
    // Mix precision for non-param inputs.
    bool is_cast = false;
    py::object cast_output;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor && meta_tensor->is_parameter()) {
        if (sig != SignatureEnumRW::kRWRead) {
          continue;
        }
      }
      // A redundant cast call may be issued if the tensor is a const Tensor.
      cast_output = DoParamMixPrecisionCast(&is_cast, obj, prim->name(), i);
    } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
      // Mix precision for tuple inputs.
      cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj, prim->name(), i);
    }
    if (is_cast) {
      op_exec_info->op_inputs[i] = cast_output;
    }
  }
  std::vector<SignatureEnumDType> dtypes;
  bool has_dtype_sig = GetSignatureType(prim, &dtypes);
  std::map<SignatureEnumDType, TypeId> dst_types;
  if (has_dtype_sig) {
    // Fetch info for the implicit cast.
    auto type_indexes = GetTypeIndex(dtypes);
    dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
  }
  MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name;
  DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
}
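// Note: GetInput maps a Python argument to an AnfNode in the current graph:
// parameters (op_mask set) become free parameters of the df builder graph, known
// op outputs are resolved through GetObjNode, tensor tuples are packed with
// MakeTuple, and everything else becomes a value node.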
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
  AnfNodePtr node = nullptr;
  std::string obj_id = GetId(obj);
  if (op_mask) {
    MS_LOG(DEBUG) << "Cell parameters(weights)";
    // Get the parameter name from the parameter object.
    auto name_attr = parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have a name attribute";
    }
    auto param_name = py::cast<std::string>(name_attr);
    auto df_builder = GetDfbuilder();
    MS_EXCEPTION_IF_NULL(df_builder);
    if (graph_info_map_.at(df_builder).params.find(obj_id) == graph_info_map_.at(df_builder).params.end()) {
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      free_param->debug_info()->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(obj);
      free_param->set_default_param(value);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      SetParamNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetParamNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
      SetNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
      return free_param;
    }
    node = graph_info_map_.at(df_builder).node_map[obj_id].first;
    MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
    return node;
  }
  if (graph_info_map_.at(curr_g_).node_map.find(obj_id) != graph_info_map_.at(curr_g_).node_map.end()) {
    // op(x, y)
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj, obj_id);
  } else if (py::isinstance<py::tuple>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    auto tuple = obj.cast<py::tuple>();
    // cell((1, 2)): mixing scalars and tensors in a tuple is not supported.
    if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
      return MakeValueNode(obj, obj_id);
    }
    std::vector<AnfNodePtr> args;
    args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple_size = tuple.size();
    for (size_t i = 0; i < tuple_size; i++) {
      args.emplace_back(GetInput(tuple[i], false));
    }
    auto cnode = curr_g_->NewCNode(args);
    SetNodeMapInGraphInfoMap(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    node = MakeValueNode(obj, obj_id);
  }
  node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
                  : MS_LOG(DEBUG) << "Get input node " << node->ToString() << " " << obj_id;
  return node;
}
void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (!grad_flag()) {
    return;
  }
  auto op_index = op_exec_info->op_index;
  auto output_value = PyAttrValue(out_real);
  MS_EXCEPTION_IF_NULL(output_value);
  std::vector<tensor::TensorPtr> output_tensors;
  TensorValueToTensor(output_value, &output_tensors);
  if (op_index_with_tensor_id_.find(op_index) == op_index_with_tensor_id_.end()) {
    // First step.
    std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) {
      op_index_with_tensor_id_[op_index].emplace_back(tensor->id());
    });
    return;
  }
  auto ms_context = MsContext::GetInstance();
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const auto &tensor_id_list = op_index_with_tensor_id_[op_index];
  for (size_t i = 0; i < tensor_id_list.size(); ++i) {
    auto tensor_id = tensor_id_list[i];
    if (tensor_id_with_tensor_.find(tensor_id) != tensor_id_with_tensor_.end()) {
      auto &new_tensor = output_tensors[i];
      auto &tensors_in_value_node = tensor_id_with_tensor_[tensor_id];
      std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
        MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id "
                      << tensor->id() << ", device address " << tensor->device_address().get()
                      << " with new tensor obj " << new_tensor.get() << ", tensor id " << new_tensor->id()
                      << ", device address " << new_tensor->device_address().get();
        tensor->set_shape(new_tensor->shape());
        tensor->set_data_type(new_tensor->data_type());
        if (target != kCPUDevice) {
          tensor->set_device_address(new_tensor->device_address());
        } else {
          auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
          auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
          auto old_ptr = old_device_address->GetMutablePtr();
          auto new_ptr = new_device_address->GetPtr();
          MS_EXCEPTION_IF_NULL(old_ptr);
          MS_EXCEPTION_IF_NULL(new_ptr);
          auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
          if (ret != EOK) {
            MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
          }
        }
      });
    }
  }
}
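// Note: the forward graph holds op outputs as value nodes; SaveTensorsInValueNode
// indexes those tensors by id so that UpdateAbstractAndDeviceAddress above can
// patch them with freshly computed shapes and device addresses on later steps,
// and CleanTensorsInValueNode releases them through the session afterwards.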
void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  tensor_id_with_tensor_.clear();
  const auto &func_graph = resource->func_graph();
  const auto &value_node_list = func_graph->value_nodes();
  for (const auto &elem : value_node_list) {
    auto value_node = elem.first->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    std::vector<tensor::TensorPtr> tensors;
    TensorValueToTensor(value_node->value(), &tensors);
    for (const auto &tensor : tensors) {
      if (tensor->device_address() != nullptr) {
        tensor_id_with_tensor_[tensor->id()].emplace_back(tensor);
        MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id()
                      << ", device address " << tensor->device_address().get();
      }
    }
  }
}
void PynativeExecutor::CleanTensorsInValueNode() {
  // Cleaning is only needed under the ms backend policy, and the session should not be nullptr in the ms backend.
  if (session == nullptr) {
    return;
  }
  auto useless_tensors = std::make_shared<std::vector<tensor::TensorPtr>>();
  for (const auto &id_tensor_pair : tensor_id_with_tensor_) {
    std::copy(id_tensor_pair.second.begin(), id_tensor_pair.second.end(), std::back_inserter(*useless_tensors));
  }
  session->CleanUselessTensors(useless_tensors);
}
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
  auto &out = graph_info_map_.at(curr_g_).node_map[obj_id];
  if (out.second.size() == 1 && out.second[0] == -1) {
    return out.first;
  }
  MS_LOG(DEBUG) << "Output size " << out.second.size();
  // Params node.
  if (graph_info_map_.at(curr_g_).params.find(obj_id) != graph_info_map_.at(curr_g_).params.end()) {
    auto para_node = out.first;
    for (auto &idx : out.second) {
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node,
                                                    NewValueNode(idx)};
      para_node = curr_g_->NewCNode(tuple_get_item_inputs);
    }
    return para_node;
  }
  // Normal node.
  auto node = out.first->cast<CNodePtr>();
  auto abs = node->abstract();
  ValuePtr out_obj = nullptr;
  if (node->forward().first != nullptr) {
    out_obj = node->forward().first;
  } else {
    out_obj = PyAttrValue(obj);
  }
  for (auto &idx : out.second) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
    node = curr_g_->NewCNode(tuple_get_item_inputs);
    if (out_obj->isa<ValueTuple>()) {
      node->add_input_value(out_obj, "");
      node->add_input_value(MakeValue(idx), "");
      out_obj = (*out_obj->cast<ValueTuplePtr>())[idx];
      node->set_forward(out_obj, "");
    }
    if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
      auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
      MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
      node->set_abstract(prim_abs);
    }
  }
  if (node->abstract() != nullptr) {
    node_abs_map_[obj_id] = node->abstract();
  }
  MS_LOG(DEBUG) << "GetObjNode output " << node->DebugString(6);
  return node;
}
AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
  ValuePtr converted_ret = nullptr;
  parse::ConvertData(obj, &converted_ret);
  auto node = NewValueNode(converted_ret);
  SetNodeMapInGraphInfoMap(curr_g_, obj_id, node);
  return node;
}
void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real,
                                         const AnfNodePtr &cnode) {
  if (!need_construct_graph()) {
    MS_LOG(DEBUG) << "No need to save output";
    return;
  }
  MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString(4) << " id " << obj_id;
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    auto size = static_cast<int64_t>(value.size());
    if (size > 1) {
      for (int64_t i = 0; i < size; ++i) {
        auto value_id = GetId(value[i]);
        SetNodeMapInGraphInfoMap(curr_g_, value_id, cnode, i);
      }
    }
  }
  SetNodeMapInGraphInfoMap(curr_g_, obj_id, cnode);
  SetPyObjInGraphInfoMap(curr_g_, obj_id);
}
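// Attach forward information to the freshly built cnode so the bprop pass can later replace
// forward values: each input is tagged with the op index that produced it (if known), and the
// real output value is stored via set_forward under this op's own index.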
void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node,
                                     const py::object &out_real) {
  if (!grad_flag() || node == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Save the input objects
  size_t size = op_exec_info->op_inputs.size();
  for (size_t i = 0; i < size; i++) {
    auto obj = op_exec_info->op_inputs[i];
    auto obj_id = GetId(obj);
    if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) {
      cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]);
    } else {
      cnode->add_input_value(nullptr, "");
    }
  }
  // Save the output object
  auto output_value = PyAttrValue(out_real);
  MS_EXCEPTION_IF_NULL(output_value);
  cnode->set_forward(output_value, op_exec_info->op_index);
  auto out_id = GetId(out_real);
  if (py::isinstance<py::tuple>(out_real)) {
    auto tuple_item = py::cast<py::tuple>(out_real);
    for (size_t i = 0; i < tuple_item.size(); i++) {
      auto tuple_item_id = GetId(tuple_item[i]);
      obj_to_forward_id_[tuple_item_id] = op_exec_info->op_index;
    }
  }
  obj_to_forward_id_[out_id] = op_exec_info->op_index;
}
void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) {
  if (t_map == nullptr) {
    return;
  }
  for (size_t i = 0; i < tuple->size(); i++) {
    ValuePtr tuple_i = (*tuple)[i];
    if (tuple_i->isa<tensor::Tensor>()) {
      auto t = tuple_i->cast<tensor::TensorPtr>();
      (*t_map)[t->id()] = t;
    } else if (tuple_i->isa<ValueTuple>()) {
      GenTupleMap(tuple_i->cast<ValueTuplePtr>(), t_map);
    }
  }
  MS_LOG(DEBUG) << "End GenTupleMap " << tuple->ToString();
}
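// Deep-copy a value tuple while dropping the device address of every tensor, so the copied
// values no longer pin device memory. Nested tuples are handled recursively.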
ValuePtr PynativeExecutor::CleanTupleAddr(const ValueTuplePtr &tuple) {
  std::vector<ValuePtr> value_list;
  for (size_t i = 0; i < tuple->size(); i++) {
    ValuePtr tuple_i = (*tuple)[i];
    if (tuple_i->isa<tensor::Tensor>()) {
      auto t = tuple_i->cast<tensor::TensorPtr>();
      auto new_tensor = std::make_shared<tensor::Tensor>(*t);
      new_tensor->set_device_address(nullptr);
      value_list.emplace_back(new_tensor);
    } else if (tuple_i->isa<ValueTuple>()) {
      value_list.emplace_back(CleanTupleAddr(tuple_i->cast<ValueTuplePtr>()));
    } else {
      MS_LOG(DEBUG) << "Tuple[i] value " << tuple_i->ToString();
      value_list.emplace_back(tuple_i);
    }
  }
  MS_LOG(DEBUG) << "End CleanTupleAddr";
  return std::make_shared<ValueTuple>(value_list);
}
py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) {
  auto backend_policy = InitEnv(op_exec_info);
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // Return an empty py::tuple on error
  py::tuple err_ret(0);
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
    return err_ret;
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}
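// Pick the backend policy for a single op: with ENABLE_GE the GE backend is used exclusively;
// otherwise the "ms" context policy maps to ms-prior and anything else to vm-only. Ops listed
// in vm_operators are always forced onto the VM.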
MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!context::IsTsdOpened(ms_context)) {
    if (!context::OpenTsd(ms_context)) {
      MS_LOG(EXCEPTION) << "Open tsd failed";
    }
  }
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  context::PynativeInitGe(ms_context);
  backend_policy = kMsBackendGeOnly;
#endif
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  return backend_policy;
}
py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                                    PynativeStatusCode *const status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // Use the VM only
      MS_LOG(INFO) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // Use GE first; fall back to the VM when GE fails
      MS_LOG(INFO) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // Use MS first; fall back to others when MS fails
      MS_LOG(INFO) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!!!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}
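// Run a single op through the Python VM. The hook/debug ops (HookBackward, InsertGradientOf,
// stop_gradient) simply pass their inputs through, copying a tensor when it came from a forward
// cnode; every other op invokes the primitive's registered python compute function.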
py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  auto &op_inputs = op_exec_info->op_inputs;
  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf" ||
      op_exec_info->op_name == "stop_gradient") {
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      auto input_obj_id = GetId(input);
      auto tensor = py::cast<tensor::TensorPtr>(input);
      MS_EXCEPTION_IF_NULL(tensor);
      if (obj_to_forward_id_.find(input_obj_id) == obj_to_forward_id_.end() &&
          op_exec_info->op_name == "HookBackward") {
        // The input object is not an output of a forward cnode, e.g. a parameter
        result[i] = tensor;
      } else {
        // The input object is an output of a forward cnode
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
        new_tensor->set_device_address(tensor->device_address());
        new_tensor->set_sync_status(tensor->sync_status());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto primitive = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(primitive);
  auto result = primitive->RunPyComputeFunction(op_inputs);
  if (py::isinstance<py::none>(result)) {
    MS_LOG(ERROR) << "VM got result none, please check whether the compute function failed to be found";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  // Wrap the result into a tuple
  py::tuple tuple_result = py::make_tuple(result);
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "RunOpInVM end";
  return std::move(tuple_result);
}
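// Run a single op on the ms backend: lazily create the session, pack the input tensors, build
// an OpRunInfo together with the single-op graph info (which doubles as the compile-cache key),
// and convert the backend outputs back into python data.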
py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(status);
  MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
  if (session == nullptr) {
    std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    session = session::SessionFactory::Get().Create(device_target);
    MS_EXCEPTION_IF_NULL(session);
    session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  }
  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int64_t> tensors_mask;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  // Get graph info to check whether the single-op graph already exists in the cache
  std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
#if defined(__APPLE__)
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive,
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->next_op_name,
                                    static_cast<int>(op_exec_info->next_input_index)};
#else
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive,
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->next_op_name,
                                    op_exec_info->next_input_index};
#endif
  VectorRef outputs;
  session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
  if (op_exec_info->is_dynamic_shape) {
    op_exec_info->abstract = op_run_info.abstract;
  }
  auto result = BaseRefToPyData(outputs);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "End run op [" << op_exec_info->op_name << "] with backend policy ms";
  return result;
}
void PynativeExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); }
void PynativeExecutor::PopGraphStack() {
  if (graph_stack_.empty()) {
    MS_LOG(EXCEPTION) << "Stack graph_stack_ is empty";
  }
  graph_stack_.pop();
  if (!graph_stack_.empty()) {
    curr_g_ = graph_stack_.top();
  }
}
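// A cell id is the object id of the cell followed by the broadened shape and type of each
// argument, so the same cell called with differently shaped inputs gets a distinct id. The
// resulting string looks roughly like "<obj id>_(2, 3)_Tensor[Float32]" (illustrative only).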
std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
  auto cell_id = GetId(cell);
  for (size_t i = 0; i < args.size(); i++) {
    std::string arg_id = GetId(args[i]);
    auto it = node_abs_map_.find(arg_id);
    if (it != node_abs_map_.end()) {
      cell_id += "_" + it->second->BuildShape()->ToString();
      cell_id += "_" + it->second->BuildType()->ToString();
    } else {
      auto abs = PyAttrValue(args[i])->ToAbstract();
      auto config = abstract::AbstractBase::kBroadenTensorOnly;
      abs = abs->Broaden(config);
      cell_id += "_" + abs->BuildShape()->ToString();
      cell_id += "_" + abs->BuildType()->ToString();
      node_abs_map_[arg_id] = abs;
    }
  }
  return cell_id;
}
bool PynativeExecutor::IsNotNestedGrad() const {
  MS_LOG(DEBUG) << "Grad nested count is " << grad_order_;
  return grad_order_ <= 1;
}
bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
  return std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
                     [&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
}
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) {
  return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
    return !value.bprop_cell_id.empty() && cell_id.find(value.bprop_cell_id) != std::string::npos;
  });
}
void PynativeExecutor::SubNestedGradOrder() {
  if (grad_order_ > 0) {
    --grad_order_;
  }
}
bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) {
  return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfo &value) {
    return value.cell_id == cell_id && (!is_grad || value.is_grad);
  });
}
void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
  if (top_cell_list_.empty() && !graph_stack_.empty()) {
    graph_id_ = 0;
    graph_info_map_.clear();
    cell_sw_map_.clear();
    cell_graph_list_.clear();
    std::stack<FuncGraphPtr>().swap(graph_stack_);
  }
  if (dynamic_cell_) {
    VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id);
  }
}
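// GetDfbuilder and GetResource below resolve the top-cell entry that owns a cell id. With an
// empty id they pick the "nearest" top cell: the only entry, the last entry for grad order 0/1,
// or the second-to-last entry when grads are nested more deeply.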
FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
  // Cell id is empty, get the nearest dfbuilder
  if (cell_id.empty() && !top_cell_list_.empty()) {
    if (top_cell_list_.size() == 1) {
      return top_cell_list_.begin()->df_builder;
    }
    if (grad_order_ == 0 || grad_order_ == 1) {
      return top_cell_list_.back().df_builder;
    }
    if (top_cell_list_.size() < 2) {
      MS_LOG(EXCEPTION) << "Top cell list size must be at least 2";
    }
    MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
    // Grad order is at least 2
    auto it = top_cell_list_.end();
    std::advance(it, -2);
    return it->df_builder;
  }
  // If a top graph holds this cell
  for (const auto &it : top_cell_list_) {
    if (cell_id.find(it.cell_id) != std::string::npos) {
      return it.df_builder;
    }
  }
  // Current cell is not a top graph, get the first top cell
  if (!top_cell_list_.empty()) {
    return top_cell_list_.front().df_builder;
  }
  return nullptr;
}
ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) {
  // Cell id is empty, get the nearest resource
  if (cell_id.empty() && !top_cell_list_.empty()) {
    if (top_cell_list_.size() == 1) {
      return top_cell_list_.begin()->resource;
    }
    if (grad_order_ == 0 || grad_order_ == 1) {
      return top_cell_list_.back().resource;
    }
    if (top_cell_list_.size() < 2) {
      MS_LOG(EXCEPTION) << "Top cell list size must be at least 2";
    }
    MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
    // Grad order is at least 2
    auto it = top_cell_list_.end();
    std::advance(it, -2);
    return it->resource;
  }
  for (const auto &it : top_cell_list_) {
    if (cell_id.find(it.cell_id) != std::string::npos) {
      return it.resource;
    }
  }
  // Current cell is not a top graph, get the first top cell
  if (!top_cell_list_.empty()) {
    return top_cell_list_.front().resource;
  }
  return nullptr;
}
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                            parse::AstMainType type) {
  MS_EXCEPTION_IF_NULL(ast);
  if (py::isinstance<py::none>(node)) {
    MS_LOG(DEBUG) << "Get none type node!";
    return "";
  }
  auto node_type = ast->GetNodeType(node);
  MS_EXCEPTION_IF_NULL(node_type);
  // Check node type
  parse::AstMainType node_main_type = node_type->main_type();
  if (node_main_type != type) {
    MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
    return "";
  }
  std::string node_name = node_type->node_name();
  MS_LOG(DEBUG) << "Ast node is " << node_name;
  return node_name;
}
void PynativeExecutor::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
  MS_EXCEPTION_IF_NULL(ast);
  py::list args = ast->GetArgs(fn_node);
  for (size_t i = 1; i < args.size(); i++) {
    std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
    MS_LOG(DEBUG) << "Input arg name: " << arg_name;
    cell_input_args_.emplace(arg_name);
  }
}
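// The Parse* helpers below walk the cell's python AST to decide whether its construct is
// dynamic, i.e. contains control flow or assignments that depend on the cell's inputs and would
// make the traced graph differ from run to run.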
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse if/while expr";
  py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
  const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
  if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
    py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
    py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
    if (comparators_node.empty()) {
      MS_LOG(DEBUG) << "Get comparators node failed!";
      return false;
    }
    auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
    auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
    if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
      py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
      left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
    }
    MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
    if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
        unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
      return true;
    }
  }
  // e.g. `if flag:` where flag is a cell input
  if (node_name == parse::NAMED_PRIMITIVE_NAME) {
    std::string id = py::cast<std::string>(test_node.attr("id"));
    if (cell_input_args_.find(id) != cell_input_args_.end()) {
      return true;
    }
  }
  return false;
}
bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse assign expr";
  py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
  const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
  if (node_name == parse::NAMED_PRIMITIVE_CALL) {
    py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
    const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
    if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
      py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
      py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
      const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
      if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end()) {
        return true;
      }
    }
  }
  return false;
}
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse for expr";
  py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
  if (py::isinstance<py::none>(body_node)) {
    MS_LOG(DEBUG) << "Body of the for expression is none!";
    return false;
  }
  py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
  size_t count = LongToSize(pcount);
  MS_LOG(DEBUG) << "The for nodes count in body is " << count;
  for (size_t i = 0; i < count; ++i) {
    auto it = py::cast<py::list>(body_node)[i];
    const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
    if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
      return true;
    }
  }
  return false;
}
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
  MS_EXCEPTION_IF_NULL(ast);
  py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
  if (py::isinstance<py::none>(func_obj)) {
    MS_LOG(DEBUG) << "Body of the cell is none!";
    return false;
  }
  py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
  size_t count = IntToSize(pcount);
  MS_LOG(DEBUG) << "The nodes count in body is " << count;
  bool ret = false;
  for (size_t i = 0; i < count; ++i) {
    auto node = py::cast<py::list>(func_obj)[i];
    const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
    if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
      ret = ParseAssignExprNode(ast, node);
    } else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
      ret = ParseForExprNode(ast, node);
    } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
      ret = ParseIfWhileExprNode(ast, node);
    }
    if (ret) {
      MS_LOG(INFO) << "Current cell is dynamic!";
      break;
    }
  }
  return ret;
}
std::string PynativeExecutor::GetCellInfo(const py::object &cell) {
  if (py::isinstance<Cell>(cell)) {
    auto c_cell = py::cast<CellPtr>(cell);
    MS_EXCEPTION_IF_NULL(c_cell);
    auto cell_info = c_cell->ToString();
    return cell_info;
  }
  return "";
}
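// A cell is judged dynamic purely from its source: if any assign/for/if/while statement in the
// construct body depends on an input argument, the cached graph cannot be reused safely.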
bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
  std::string cell_info = GetCellInfo(cell);
  if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
    return false;
  }
  // Use AST parsing to check whether the construct of the cell may change between runs
  auto ast = std::make_shared<parse::ParseAst>(cell);
  bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
  if (!success) {
    MS_LOG(ERROR) << "Parse code to ast tree failed";
    return false;
  }
  py::object fn_node = ast->GetAstNode();
  // Collect the names of the input args to initialize the dynamic variables
  ParseInputArgs(ast, fn_node);
  // Parse the body context
  bool ret = false;
  ret = ParseBodyContext(ast, fn_node);
  cell_input_args_.clear();
  return ret;
}
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
  auto cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
  // Check whether the cell needs a grad graph to be constructed
  if (!dynamic_cell_ && graph_stack_.empty() && CheckCellGraph(cell_id)) {
    if (top_cell_list_.empty()) {
      MS_LOG(EXCEPTION) << "Top cell list is empty";
    }
    if (IsTopGraph(cell_id)) {
      op_index_map_.clear();
    }
    MS_LOG(INFO) << "NewGraph already compiled";
    return;
  }
  // Init the resources for constructing the forward graph and grad graph
  auto g = std::make_shared<FuncGraph>();
  curr_g_ = g;
  ClearResidualRes(cell_id);
  if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
    MakeNewTopGraph(cell_id, args, g);
  }
  PushCurrentGraphToStack();
  if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) {
    GraphInfo graph_info = GraphInfo(cell_id);
    graph_info_map_.emplace(curr_g_, graph_info);
  }
  for (size_t i = 0; i < args.size(); ++i) {
    auto param = args[i];
    auto new_param = g->add_parameter();
    std::string param_id = GetId(param);
    SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
    SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
    SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
  }
  // Check whether the construct of the cell may change between runs
  if (!dynamic_cell_) {
    dynamic_cell_ = IsDynamicCell(cell);
    MS_LOG(DEBUG) << "Cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_;
  }
}
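// A new top graph is made when an outermost cell is entered: it owns the df_builder (the grad
// graph under construction) and the pipeline resource, and all per-run op bookkeeping is reset.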
void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) {
  for (const auto &arg : args) {
    if (py::isinstance<tensor::Tensor>(arg)) {
      auto tensor = arg.cast<tensor::TensorPtr>();
      if (tensor && tensor->is_parameter()) {
        MS_EXCEPTION(TypeError) << "The inputs could not be Parameter.";
      }
    }
  }
  // Clear state left by the previous run of this cell
  auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                         [&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
  if (it != top_cell_list_.end()) {
    top_cell_list_.erase(it);
  }
  dynamic_cell_ = false;
  op_index_map_.clear();
  op_index_with_tensor_id_.clear();
  auto df_builder = std::make_shared<FuncGraph>();
  GraphInfo graph_info = GraphInfo(cell_id);
  graph_info_map_.emplace(df_builder, graph_info);
  auto resource = std::make_shared<pipeline::Resource>();
  resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
  auto top_cell_info = TopCellInfo(resource, df_builder, nullptr, cell_id);
  top_cell_list_.emplace_back(top_cell_info);
  MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
}
void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                                  bool is_param) {
  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
    return;
  }
  auto tuple = args.cast<py::tuple>();
  auto tuple_size = static_cast<int64_t>(tuple.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    auto id = GetId(tuple[i]);
    if (is_param && node->isa<Parameter>()) {
      auto param = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      SetParamNodeMapInGraphInfoMap(g, id, param);
    }
    SetNodeMapInGraphInfoMap(g, id, node, i);
    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param);
  }
}
void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args,
                                                      const AnfNodePtr &node,
                                                      const std::vector<int64_t> &index_sequence, bool is_param) {
  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
    return;
  }
  auto tuple = args.cast<py::tuple>();
  auto tuple_size = static_cast<int64_t>(tuple.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    std::vector<int64_t> tmp = index_sequence;
    tmp.emplace_back(i);
    auto id = GetId(tuple[i]);
    if (is_param && node->isa<Parameter>()) {
      auto param = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      SetParamNodeMapInGraphInfoMap(g, id, param);
    }
    SetNodeMapInGraphInfoMap(g, id, node, tmp);
    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param);
  }
}
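// EndGraphInner finishes tracing a cell. If the returned object was never recorded as a node
// (e.g. a python tuple assembled from several op results), a MakeTuple cnode or a value node is
// created for it before the graph output is set.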
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
  if (!dynamic_cell_ && graph_stack_.empty() && CheckCellGraph(cell_id)) {
    MS_LOG(INFO) << "EndGraph already compiled";
    return;
  }
  auto out_id = GetId(out);
  // e.g. x = op1(...), y = op2(...), return (x, y): the returned tuple itself was never recorded
  if (graph_info_map_.at(curr_g_).node_map.find(out_id) == graph_info_map_.at(curr_g_).node_map.end()) {
    if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
      auto tuple = out.cast<py::tuple>();
      auto tuple_size = static_cast<int64_t>(tuple.size());
      std::vector<AnfNodePtr> inputs;
      inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
      for (int64_t i = 0; i < tuple_size; i++) {
        inputs.emplace_back(GetInput(tuple[i], false));
      }
      auto cnode = curr_g_->NewCNode(inputs);
      SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
      SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
    } else {
      MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
      MakeValueNode(out, out_id);
    }
  }
  EndGraphByOutId(cell, cell_id, out, out_id, args);
}
void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
                                       const std::string &out_id, const py::args &args) {
  AnfNodePtr output_node = GetObjNode(out, out_id);
  curr_g_->set_output(output_node);
  MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
  if (EndBpropGraph(cell_id)) {
    return;
  }
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  resource->manager()->AddFuncGraph(curr_g_);
  UpdateCellGraph(cell, curr_g_, cell_id, true, false);
  auto newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args);
  if (graph_stack_.size() > 1) {
    std::vector<AnfNodePtr> inputs;
    inputs.emplace_back(NewValueNode(curr_g_));
    PopGraphStack();
    // Call the inner graph from the previous (outer) graph
    auto graph_prev = graph_stack_.top();
    for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], false);
      inputs.emplace_back(input);
    }
    auto out_cnode = graph_prev->NewCNode(inputs);
    SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
    SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
    SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
  } else {
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
      DumpIR("before_resolve.ir", newfg);
    }
    parse::ResolveFuncGraph(newfg, resource);
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
      DumpIR("after_resolve.ir", newfg);
    }
    resource->set_func_graph(newfg);
    PopGraphStack();
  }
}
bool PynativeExecutor::EndBpropGraph(const string &cell_id) {
  auto is_bprop_graph = IsBpropGraph(cell_id);
  if (is_bprop_graph) {
    if (IsNotNestedGrad()) {
      PopGraphStack();
    }
    return true;
  }
  return false;
}
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                                       bool need_cloned, bool is_grad) {
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    // For a custom bprop cell, just save the backward graph
    auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                           [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
    if (it != cell_graph_list_.end()) {
      it->is_grad = is_grad;
      it->fg = g;
      MS_LOG(DEBUG) << "Update bprop backward graph";
    } else {
      py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
      auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func));
      cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
    }
    return;
  }
  FuncGraphPtr tmp = g;
  if (need_cloned && !IsNotNestedGrad()) {
    auto cloned_curr_g = BasicClone(g);
    graph_info_map_[cloned_curr_g] = graph_info_map_.at(g);
    tmp = cloned_curr_g;
    MS_LOG(DEBUG) << "Replace cur graph " << g.get() << " with cloned new " << cloned_curr_g.get();
  }
  for (auto &it : cell_graph_list_) {
    if (it.cell_id != cell_id) {
      continue;
    }
    it.is_grad = is_grad;
    if (need_cloned) {
      it.fg = tmp;
    }
    if (!need_cloned && !is_grad) {
      graph_info_map_[g] = graph_info_map_.at(it.fg);
      graph_info_map_.erase(it.fg);
      MS_LOG(DEBUG) << "Replace cur graph " << it.fg.get() << " with new " << g.get();
      it.fg = g;
    }
    return;
  }
  MS_LOG(DEBUG) << "Add new cell graph " << cell_id;
  auto cell_info = CellInfo(false, true, tmp, cell_id, "");
  cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
}
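// Build the grad graph via ad::Grad. For a cell with a custom bprop, the bprop function is
// converted into a bprop-cut graph attached through transforms, and the forward parameters of
// the resulting graph are frozen to the current argument values.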
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
                                             const std::string &cell_id, const py::args &args) {
  bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
  if (is_custom_bprop) {
    size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
    if (par_number > 0) {
      MS_LOG(EXCEPTION) << "When the user defines the net bprop, there are " << par_number
                        << " parameters that are not supported in the net.";
    }
    MS_LOG(INFO) << "Use cell custom bprop function.";
    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
    if (bprop_graph != nullptr) {
      (void)g->transforms().emplace(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
    }
  }
  // Obtain the grad graph
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR("fg.ir", g);
  }
  auto is_top = IsTopGraph(cell_id);
  MS_LOG(DEBUG) << "Grad top cell " << is_top;
  set_need_replace_forward(IsNotNestedGrad());
  auto newfg = ad::Grad(g, r, is_top);
  if (is_custom_bprop) {
    auto params = newfg->parameters();
    auto manager = Manage({newfg}, false);
    if (args.size() > params.size()) {
      MS_EXCEPTION(TypeError) << "The number of arguments " << args.size()
                              << " is more than the number of parameters required, which is " << params.size();
    }
    for (size_t i = 0; i < args.size(); i++) {
      ValuePtr value = PyAttrValue(args[i]);
      auto v_node = NewValueNode(value);
      manager->Replace(params[i], v_node);
    }
    UpdateCellGraph(cell, newfg, cell_id, false, false);
  }
  return newfg;
}
std::string PynativeExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args,
                                            py::object *forward_args, py::object *sens) {
  auto size = args.size();
  size_t forward_args_size = size;
  if (has_sens) {
    if (size >= 1) {
      --forward_args_size;
      if (sens != nullptr) {
        *sens = args[forward_args_size];
      }
    }
    py::tuple f_args(forward_args_size);
    for (size_t i = 0; i < forward_args_size; ++i) {
      f_args[i] = args[i];
    }
    *forward_args = f_args;
  }
  const auto &cell_id = GetCellId(cell, *forward_args);
  return cell_id;
}
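// GradNetInner drives grad-graph compilation for a cell: it returns early when the grad graph
// is already compiled and neither sens nor weights changed; otherwise it sets the input and
// weight params on the df_builder, derives the real grad graph, runs the pynative optimize,
// task-emit and execute actions, and caches the result.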
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                    const py::args &args) {
  auto size = args.size();
  py::object sens = py::none();
  py::object forward_args = args;
  const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args, &forward_args, &sens);
  MS_LOG(DEBUG) << "GradNet start " << args.size() << " " << cell_id;
  const auto &sw_changed = CheckCellChanged(cell_id, weights, sens);
  if (!dynamic_cell_ && !sw_changed.second && CheckCellGraph(cell_id, true)) {
    MS_LOG(INFO) << "Grad graph already compiled";
    return;
  }
  // Nested graph
  if (CheckCellGraph(cell_id) && !graph_stack_.empty()) {
    MS_LOG(DEBUG) << "Set nested top graph";
    SetNestedTopGraph(cell, forward_args, cell_id);
  }
  auto df_builder = GetDfbuilder(cell_id);
  MS_EXCEPTION_IF_NULL(df_builder);
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
  // Set all params (inputs + weights)
  SetGradGraphParams(df_builder, resource, size);
  // Get the params (weights) that require derivatives
  auto w_args = GetWeightsArgs(weights, df_builder);
  // Get the parameter items and add their values to args_spec
  auto args_spec = GetArgsSpec(args, df_builder);
  resource->set_args_spec(args_spec);
  // Get the real grad graph
  GradGraph(resource->func_graph(), grad, w_args, size, cell_id);
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR("before_grad.ir", resource->func_graph());
    DumpIR("after_grad.ir", df_builder);
  }
  resource->set_func_graph(df_builder);
  resource->manager()->KeepRoots({df_builder});
  resource->results()[pipeline::kBackend] = compile::CreateBackend();
  MS_LOG(INFO) << "Start opt";
  PynativeOptimizeAction(resource);
  SaveTensorsInValueNode(resource);
  TaskEmitAction(resource);
  ExecuteAction(resource);
  UpdateCellGraph(cell, curr_g_, cell_id, false, true);
  UpdateGraphInfoMap(cell_id);
  resource->Clean();
}
std::pair<bool, bool> PynativeExecutor::CheckCellChanged(const std::string &cell_id, const py::object &weights,
                                                         const py::object &sens) {
  auto fn = [](const py::object &arg) {
    std::string arg_id;
    if (py::isinstance<tensor::Tensor>(arg)) {
      auto tensor_ptr = py::cast<tensor::TensorPtr>(arg);
      auto dtype = tensor_ptr->data_type();
      auto shape = tensor_ptr->shape();
      std::stringstream ss;
      std::for_each(shape.begin(), shape.end(), [&ss](int i) { ss << i; });
      arg_id = ss.str() + std::to_string(dtype);
    } else {
      arg_id = std::string(py::str(arg));
    }
    return arg_id;
  };
  std::string sens_id = "sens";
  if (!py::isinstance<py::none>(sens)) {
    sens_id = fn(sens);
  }
  std::string weights_id = fn(weights);
  std::pair<bool, bool> sens_weights_changed(false, false);
  // Check whether sens or weights changed
  auto it = cell_sw_map_.find(cell_id);
  if (it != cell_sw_map_.end() && it->second.first != sens_id) {
    MS_LOG(DEBUG) << "Sens_id, cur is " << it->second.first << " new is " << sens_id;
    sens_weights_changed.first = true;
  }
  if (it != cell_sw_map_.end() && it->second.second != weights_id) {
    MS_LOG(DEBUG) << "Weights_id, cur is " << it->second.second << " new is " << weights_id;
    sens_weights_changed.second = true;
  }
  cell_sw_map_[cell_id] = std::make_pair(sens_id, weights_id);
  return sens_weights_changed;
}
void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id) {
  if (IsTopGraph(cell_id)) {
    return;
  }
  ResourcePtr resource = nullptr;
  auto ia = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                         [&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
  if (ia != top_cell_list_.end()) {
    resource = GetResource(ia->cell_id);
    MS_EXCEPTION_IF_NULL(resource);
    MS_LOG(DEBUG) << "Find old resource " << resource.get();
  }
  if (resource == nullptr) {
    resource = std::make_shared<pipeline::Resource>();
    resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
    MS_LOG(DEBUG) << "Make new resource " << resource.get();
  }
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
  GraphInfo graph_info = GraphInfo(cell_id);
  graph_info_map_.emplace(df_builder, graph_info);
  auto top_cell_info = TopCellInfo(resource, df_builder, nullptr, cell_id);
  top_cell_list_.emplace_back(top_cell_info);
  FuncGraphPtr forward_graph = nullptr;
  auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                         [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
  if (ib != cell_graph_list_.end()) {
    forward_graph = ib->fg;
  }
  MS_EXCEPTION_IF_NULL(forward_graph);
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
      DumpIR("nested_bprop.ir", forward_graph);
    }
    // Custom bprop gets the backward graph (before opt), which is used like any other forward graph
    curr_g_ = forward_graph;
    resource->set_func_graph(forward_graph);
    return;
  }
  // Copy the weight params
  std::vector<AnfNodePtr> weights_params{};
  for (const auto &it : graph_info_map_.at(forward_graph).params) {
    if (it.second->has_default()) {
      weights_params.emplace_back(it.second);
      graph_info_map_.at(df_builder).params.emplace(it.first, it.second);
      SetNodeMapInGraphInfoMap(df_builder, it.first, it.second);
    }
  }
  MS_LOG(DEBUG) << "Get weights params size " << weights_params.size();
  df_builder->set_parameters(weights_params);
  resource->manager()->AddFuncGraph(forward_graph);
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR("nested_fg.ir", forward_graph);
  }
  set_need_replace_forward(false);
  auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args);
  resource->set_func_graph(newfg);
}
void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) {
  std::vector<AnfNodePtr> new_params;
  for (size_t i = 0; i < size; i++) {
    ParameterPtr p = std::make_shared<Parameter>(df_builder);
    new_params.emplace_back(p);
  }
  MS_LOG(DEBUG) << "GradNet weight param size " << df_builder->parameters().size();
  // df_builder->parameters() were set in GetInput; they are the weight params
  new_params.insert(new_params.end(), df_builder->parameters().begin(), df_builder->parameters().end());
  df_builder->set_parameters(new_params);
  resource->manager()->SetParameters(df_builder, new_params);
}
std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) {
  std::vector<AnfNodePtr> w_args;
  if (!py::hasattr(weights, "__parameter_tuple__")) {
    MS_LOG(DEBUG) << "No parameter_tuple found";
    return {};
  }
  auto tuple = weights.cast<py::tuple>();
  MS_LOG(DEBUG) << "Get weights tuple size " << tuple.size();
  w_args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t it = 0; it < tuple.size(); ++it) {
    auto param = tuple[it];
    auto param_id = GetId(param);
    AnfNodePtr para_node = nullptr;
    if (graph_info_map_.at(df_builder).params.find(param_id) != graph_info_map_.at(df_builder).params.end() &&
        graph_info_map_.at(df_builder).node_map.find(param_id) != graph_info_map_.at(df_builder).node_map.end()) {
      para_node = graph_info_map_.at(df_builder).node_map[param_id].first;
    } else {
      auto name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
      if (py::isinstance<py::none>(name_attr)) {
        MS_LOG(EXCEPTION) << "Parameter object should have a name attribute";
      }
      auto param_name = py::cast<std::string>(name_attr);
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(param);
      free_param->set_default_param(value);
      free_param->debug_info()->set_name(param_name);
      para_node = free_param;
    }
    w_args.emplace_back(para_node);
  }
  return w_args;
}
abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder) {
  abstract::AbstractBasePtrList args_spec;
  std::size_t size = args.size();
  auto df_params = df_builder->parameters();
  if (df_params.size() < size) {
    MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << size;
  }
  // Input params
  for (std::size_t i = 0; i < size; i++) {
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(args[i], &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "Args convert error";
    }
    bool broaden = true;
    auto abs = abstract::FromValue(converted, broaden);
    args_spec.emplace_back(abs);
    auto param_node = std::static_pointer_cast<Parameter>(df_params[i]);
    param_node->set_abstract(abs);
  }
  // Weight params
  for (const auto &param : df_params) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    if (param_node->has_default()) {
      ValuePtr value = param_node->default_param();
      auto ptr = value->ToAbstract();
      MS_EXCEPTION_IF_NULL(ptr);
      args_spec.emplace_back(ptr);
      param_node->set_abstract(ptr);
    }
  }
  MS_LOG(DEBUG) << "Args_spec size " << args_spec.size();
  return args_spec;
}
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op,
                                 const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
  FuncGraphPtr top_g = nullptr;
  auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                         [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
  if (it != cell_graph_list_.end()) {
    top_g = it->fg;
  }
  MS_EXCEPTION_IF_NULL(top_g);
  auto nparam = top_g->parameters().size();
  MS_LOG(DEBUG) << "Top graph input params size " << nparam;
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  auto df_builder = GetDfbuilder(cell_id);
  MS_EXCEPTION_IF_NULL(df_builder);
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  df_builder->debug_info()->set_name(ss.str());
  auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g->parameters(), weights);
  std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
  auto df_params = df_builder->parameters();
  if (df_params.size() < arg_size) {
    MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << arg_size;
  }
  for (size_t i = 0; i < arg_size; ++i) {
    inputs.emplace_back(df_params[i]);
  }
  auto out = df_builder->NewCNode(inputs);
  df_builder->set_output(out);
  resource->manager()->AddFuncGraph(df);
  resource->manager()->AddFuncGraph(df_builder);
}
void PynativeExecutor::UpdateGraphInfoMap(const std::string &cell_id) {
  std::vector<std::string> l{};
  bool index_find = false;
  for (const auto &it : cell_graph_list_) {
    if (index_find) {
      l.emplace_back(it.cell_id);
      continue;
    }
    if (it.cell_id == cell_id) {
      index_find = true;
      l.emplace_back(it.cell_id);
    }
  }
  for (const auto &it : l) {
    for (auto ic = graph_info_map_.begin(); ic != graph_info_map_.end();) {
      if (ic->second.cell_id.find(it) != std::string::npos) {
        ic = graph_info_map_.erase(ic);
      } else {
        ++ic;
      }
    }
  }
}
py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  BaseRef ret = false;
  AddNestedGradOrder();
  if (!grad_running()) {
    MS_LOG(DEBUG) << "Grad not running yet";
    return BaseRefToPyData(ret);
  }
  const auto &cell_id = GetCellId(cell, args);
  string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
  MS_LOG(DEBUG) << "Key is " << key;
  for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) {
    MS_LOG(DEBUG) << "Cur cell id " << it->cell_id;
    if (key != it->cell_id.substr(0, std::min(PTR_LEN, it->cell_id.size()))) {
      continue;
    }
    MS_LOG(DEBUG) << "Delete cellid from cell graph list";
    cell_graph_list_.erase(it);
    ret = true;
    break;
  }
  return BaseRefToPyData(ret);
}
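// Run executes the compiled grad graph: the resource is resolved by cell id (stripping a
// trailing sens argument when one is detected), the args are converted and fed to the VM eval
// function, and the forward tensors cached in value nodes are released afterwards.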
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) {
  auto cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "Run start cell id " << cell_id;
  bool has_sens = false;
  for (const auto &it : top_cell_list_) {
    if (cell_id.find(it.cell_id) != std::string::npos && cell_id != it.cell_id) {
      has_sens = true;
      break;
    }
  }
  py::object forward_args = args;
  cell_id = GetGradCellId(has_sens, cell, args, &forward_args);
  MS_LOG(DEBUG) << "Run has sens " << has_sens << " forward cell id " << cell_id;
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Run resource ptr " << resource.get();
  VectorRef arg_list;
  py::tuple converted_args = ConvertArgs(args);
  pipeline::ProcessVmArgInner(converted_args, resource, &arg_list);
  if (resource->results().find(pipeline::kOutput) == resource->results().end()) {
    MS_LOG(EXCEPTION) << "Can't find run graph output";
  }
  if (!resource->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Run graph is not VmEvalFuncPtr";
  }
  compile::VmEvalFuncPtr run = resource->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  MS_EXCEPTION_IF_NULL(run);
  std::string backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  set_grad_runing(true);
  BaseRef value = (*run)(arg_list);
  CleanTensorsInValueNode();
  set_grad_runing(false);
  MS_LOG(DEBUG) << "Eval run end " << value.ToString();
  auto out = BaseRefToPyData(value);
  if (MakeBpropNestedCnode(cell, out, cell_id)) {
    return out;
  }
  MakeNestedCnode(cell_id, args, resource, out, has_sens);
  return out;
}
bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) {
  if (graph_stack_.empty() || !py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    MS_LOG(DEBUG) << "No nested bprop grad found";
    return false;
  }
  auto out_id = GetId(out);
  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(curr_g_));
  PopGraphStack();
  for (const auto &ig : graph_info_map_.at(curr_g_).params) {
    if (!ig.second->has_default()) {
      inputs.emplace_back(ig.second);
    }
  }
  auto cnode = curr_g_->NewCNode(inputs);
  SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
  SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
  MS_LOG(DEBUG) << "Custom bprop make nested node is " << cnode->DebugString(4);
  return true;
}
void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
                                       const py::object &out, bool has_sens) {
  if (graph_stack_.empty()) {
    MS_LOG(DEBUG) << "No nested grad found";
    return;
  }
  auto graph_prev = graph_stack_.top();
  MS_EXCEPTION_IF_NULL(graph_prev);
  MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get();
  auto newfg = resource->func_graph();
  MS_EXCEPTION_IF_NULL(newfg);
  auto size = args.size();
  if (has_sens) {
    size -= 1;
  }
  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(newfg));
  for (size_t i = 0; i < size; ++i) {
    inputs.emplace_back(GetInput(args[i], false));
  }
  auto out_id = GetId(out);
  auto cnode = graph_prev->NewCNode(inputs);
  SetTupleArgsToGraphInfoMap(graph_prev, out, cnode);
  SetNodeMapInGraphInfoMap(graph_prev, out_id, cnode);
  MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
}
  2275. void PynativeExecutor::Clear(const std::string &cell_id) {
  2276. if (cell_id.empty()) {
  2277. Clean();
  2278. return;
  2279. }
  2280. MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id;
  2281. for (auto it = graph_info_map_.begin(); it != graph_info_map_.end();) {
  2282. if (it->second.cell_id.find(cell_id) != std::string::npos) {
  2283. it = graph_info_map_.erase(it);
  2284. } else {
  2285. ++it;
  2286. }
  2287. }
  2288. // Maybe exit in runop step
  2289. auto ms_context = MsContext::GetInstance();
  2290. if (ms_context != nullptr) {
  2291. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  2292. }
  2293. ConfigManager::GetInstance().ResetIterNum();
  2294. MapClear<std::unordered_map<std::string, bool>>(&cell_dynamic_map_, cell_id);
  2295. MapClear<std::unordered_map<std::string, std::pair<std::string, std::string>>>(&cell_sw_map_, cell_id);
  2296. VectorClear<std::vector<CellInfo>>(&cell_graph_list_, cell_id);
  2297. VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id);
  2298. node_abs_map_.clear();
  2299. }
void PynativeExecutor::Clean() {
  MS_LOG(DEBUG) << "Clean";
  SubNestedGradOrder();
  node_abs_map_.clear();
  obj_to_forward_id_.clear();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
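
// Full reset: Clean() plus every flag, counter, cached map, and the graph stack go
// back to their initial state.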
void PynativeExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear all res";
  Clean();
  graph_id_ = 0;
  grad_order_ = 0;
  grad_flag_ = false;
  dynamic_cell_ = false;
  grad_is_running_ = false;
  need_replace_forward_ = true;
  curr_g_ = nullptr;
  graph_info_map_.clear();
  cell_sw_map_.clear();
  cell_graph_list_.clear();
  top_cell_list_.clear();
  op_index_map_.clear();
  op_index_with_tensor_id_.clear();
  tensor_id_with_tensor_.clear();
  cell_dynamic_map_.clear();
  prim_abs_list_.clear();
  std::stack<FuncGraphPtr>().swap(graph_stack_);
}
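
// Python-facing entry points. Each forwards to its *Inner counterpart through
// PynativeExecutorTry, which is presumably a wrapper that handles errors and tidies
// executor state before an exception propagates to Python.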
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args);
}
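
// EndGraph additionally brackets the inner call with the pipeline mem_cleaner; the
// Enter/Leave naming suggests it scopes memory tracking to the end-graph phase.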
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  MS_LOG(DEBUG) << "Enter end graph process.";
  auto &mem_cleaner = pipeline::Resource::mem_cleaner();
  mem_cleaner.EnterPynativeEndGraphProcess();
  PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
  mem_cleaner.LeavePynativeEndGraphProcess();
  MS_LOG(DEBUG) << "Leave end graph process.";
}
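
// Build the grad graph for `cell` from the given GradOperation, weights, and args.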
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
}
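
// Block until the device stream has drained; a session must already exist.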
void PynativeExecutor::Sync() {
  if (session == nullptr) {
    MS_EXCEPTION(NotExistsError) << "No session has been created!";
  }
  session->SyncStream();
}
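
// Record only the outermost cell entering construct; nested cells are ignored while a
// top cell is active.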
void PynativeExecutor::EnterConstruct(const py::object &cell) {
  if (top_cell_ != nullptr) {
    return;
  }
  top_cell_ = cell.ptr();
  pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
  MS_LOG(DEBUG) << "Enter construct process.";
}
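
// Mirror of EnterConstruct: only the recorded top cell clears the marker and ends the
// construct phase.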
void PynativeExecutor::LeaveConstruct(const py::object &cell) {
  if (top_cell_ != cell.ptr()) {
    return;
  }
  top_cell_ = nullptr;
  pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
  MS_LOG(DEBUG) << "Leave construct process.";
}
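
// Expose the executor to Python as PynativeExecutor_. A rough Python-side sketch (the
// real wrapper lives in the Python package and may differ):
//   executor = PynativeExecutor_.get_instance()
//   executor.new_graph(cell, *args)
//   out = cell(*args)
//   executor.end_graph(cell, out, *args)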
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "PyNative starts a new graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "PyNative ends a graph.")
                           .def("check_graph", &PynativeExecutor::CheckGraph, "PyNative checks a grad graph.")
                           .def("grad_net", &PynativeExecutor::GradNet, "PyNative builds a grad graph.")
                           .def("clear", &PynativeExecutor::Clear, "PyNative clears status.")
                           .def("sync", &PynativeExecutor::Sync, "PyNative syncs the device stream.")
                           .def("__call__", &PynativeExecutor::Run, "PyNative executor runs the grad graph.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor sets the grad flag.")
                           .def("enter_construct", &PynativeExecutor::EnterConstruct,
                                "Do something before entering the construct function.")
                           .def("leave_construct", &PynativeExecutor::LeaveConstruct,
                                "Do something after leaving the construct function.");
                       }));
}  // namespace mindspore::pynative