pynative_execute.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/pynative/pynative_execute.h"

#include <typeinfo>
#include <map>
#include <set>
#include <memory>
#include <sstream>
#include <unordered_set>
#include <algorithm>
#include <cxxabi.h>  // for abi::__cxa_current_exception_type

#include "debug/trace.h"
#include "pybind_api/ir/tensor_py.h"
#include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/cell.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "utils/context/context_extends.h"
#include "utils/config_manager.h"
#include "utils/convert_utils_py.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
#include "backend/optimizer/common/helper.h"
#include "pipeline/jit/action.h"
#include "pipeline/pynative/base.h"
#include "pybind_api/api_register.h"
#include "vm/transform.h"
#include "frontend/optimizer/ad/grad.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/pipeline.h"
#include "pipeline/jit/pass.h"
#include "frontend/parallel/context.h"
#ifdef ENABLE_GE
#include "pipeline/pynative/pynative_execute_ge.h"
#endif
#include "debug/anf_ir_dump.h"

using mindspore::tensor::TensorPy;

const size_t PTR_LEN = 15;
// Primitives whose constant-input values cannot be inferred in PyNative mode.
static const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
                                                   "mixed_precision_cast"};
static const char kOpsFunctionModelName[] = "mindspore.ops.functional";
static const char kMSDtypeModelName[] = "mindspore.common.dtype";

namespace mindspore::pynative {
static std::shared_ptr<session::SessionBasic> session = nullptr;
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
int64_t PynativeExecutor::graph_id_ = 0;
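
// Invokes the given PynativeExecutor member function and translates any failure into a Python-visible
// exception: pybind11 exceptions are re-thrown as-is, other C++ exceptions become RuntimeError, and the
// executor state is cleared first so a failed step does not poison the next one.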
template <typename... Args>
void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
  MS_EXCEPTION_IF_NULL(executor);
  try {
    (executor->*method)(args...);
  } catch (const py::error_already_set &ex) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // Call py::print to output the function call stack to STDOUT; if the log is written to a file, the
    // user can still see the info on screen without opening the log file.
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    PynativeExecutor::GetInstance()->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    PynativeExecutor::GetInstance()->ClearRes();
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compiling graph. Exception name: " << exName;
  }
}

inline ValuePtr PyAttrValue(const py::object &obj) {
  ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  if (!converted_ret) {
    MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
  }
  return converted_ret;
}
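
// Builds a stable string id for a Python object, used as the key into the node/abstract caches: tensors
// map to their tensor id, types to "type" + type name, scalars and strings to their repr, None to "none",
// and non-empty tuples/lists to a prefix plus the recursive ids of their elements, e.g. "tuple:id0:id1:".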
static std::string GetId(const py::object &obj) {
  if (py::isinstance<tensor::Tensor>(obj)) {
    auto tensor_ptr = py::cast<tensor::TensorPtr>(obj);
    return tensor_ptr->id();
  } else if (py::isinstance<mindspore::Type>(obj)) {
    auto type_ptr = py::cast<mindspore::TypePtr>(obj);
    return "type" + type_ptr->ToString();
  } else if (py::isinstance<py::str>(obj) || py::isinstance<py::int_>(obj) || py::isinstance<py::float_>(obj)) {
    return std::string(py::str(obj));
  } else if (py::isinstance<py::none>(obj)) {
    return "none";
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    auto p_list = py::cast<py::tuple>(obj);
    string prefix = py::isinstance<py::tuple>(obj) ? "tuple:" : "list";
    if (p_list.empty()) {
      prefix = "empty";
    } else {
      std::string key;
      for (size_t i = 0; i < p_list.size(); ++i) {
        key += std::string(py::str(GetId(p_list[i]))) + ":";
      }
      prefix += key;
    }
    return prefix;
  }
  py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  return py::cast<std::string>(ret);
}

std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
  std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
  for (size_t i = 0; i < dtypes.size(); ++i) {
    auto it = type_indexes.find(dtypes[i]);
    if (it == type_indexes.end()) {
      (void)type_indexes.emplace(std::make_pair(dtypes[i], std::vector<size_t>{i}));
    } else {
      it->second.emplace_back(i);
    }
  }
  return type_indexes;
}
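
// For each signature dtype group with at least two members, picks the implicit promotion target: the
// highest-priority tensor dtype among the group's arguments, adjusted for mixed scalar/tensor cases
// (a bool tensor promotes to the scalar's type, a non-float tensor mixed with a float scalar promotes to
// float32, and uint8 mixed with an int8 tensor widens to int16). As an illustration (not taken from the
// source), an arithmetic op on an int8 tensor and a Python float would resolve that group to float32.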
std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
                                                const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
  std::map<SignatureEnumDType, TypeId> dst_type;
  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
    auto type = it->first;
    auto indexes = it->second;
    if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < 2) {
      continue;
    }
    size_t priority = 0;
    TypeId max_type = TypeId::kTypeUnknown;
    bool has_scalar_float32 = false;
    bool has_scalar_int64 = false;
    bool has_tensor_int8 = false;
    for (size_t index : indexes) {
      if (!has_scalar_float32 && py::isinstance<py::float_>(py_args[index])) {
        has_scalar_float32 = true;
      }
      if (!has_scalar_int64 && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
        has_scalar_int64 = true;
      }
      auto obj = py_args[index];
      if (py::isinstance<tensor::Tensor>(obj)) {
        auto arg = py::cast<tensor::TensorPtr>(obj);
        TypeId arg_type_id = arg->data_type();
        auto type_priority = prim::type_map.find(arg_type_id);
        if (type_priority == prim::type_map.end()) {
          continue;
        }
        if (arg_type_id == kNumberTypeInt8) {
          has_tensor_int8 = true;
        }
        if (type_priority->second > priority) {
          max_type = type_priority->first;
          priority = type_priority->second;
        }
      }
    }
    if (max_type == TypeId::kNumberTypeBool) {
      if (has_scalar_int64) {
        max_type = TypeId::kNumberTypeInt64;
      }
      if (has_scalar_float32) {
        max_type = TypeId::kNumberTypeFloat32;
      }
    }
    if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
        max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
      max_type = TypeId::kNumberTypeFloat32;
    }
    if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
      max_type = TypeId::kNumberTypeInt16;
    }
    (void)dst_type.emplace(std::make_pair(type, max_type));
  }
  return dst_type;
}

std::string TypeIdToMsTypeStr(const TypeId &type_id) {
  auto type_name = type_name_map.find(type_id);
  if (type_name == type_name_map.end()) {
    MS_LOG(EXCEPTION) << "For implicit type conversion, converting to type " << TypeIdToType(type_id)
                      << " is not supported";
  }
  return type_name->second;
}

bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
  MS_EXCEPTION_IF_NULL(dtypes);
  auto signature = prim->signatures();
  bool has_sig_dtype = false;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
                       [&has_sig_dtype](const Signature &sig) {
                         auto dtype = sig.dtype;
                         if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
                           has_sig_dtype = true;
                         }
                         return dtype;
                       });
  return has_sig_dtype;
}

void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
                   const abstract::AbstractBasePtrList &args_spec_list) {
  MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
  prim->BeginRecordAddAttr();
  AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  prim->EndRecordAddAttr();
  op_exec_info->abstract = infer_res;
  MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}
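
// Builds the cache key for a single-op graph by concatenating, "_"-separated: each input tensor's shape
// dims, dtype id, and (if present) device-address type id and format, then the primitive id, the evaluated
// attrs, and the inferred output shape/type. For illustration, a 2x3 float32 input contributes something
// like "2_3_<dtype-id>_" to the key, where <dtype-id> is the numeric TypeId enum value.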
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
                                 const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  std::string graph_info;
  // get input tensor info
  for (const auto &tensor : input_tensors) {
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_shape = tensor->shape();
    (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
                        [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
    (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
    if (tensor->device_address() != nullptr) {
      (void)graph_info.append(
        std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id()) + "_");
      (void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format() +
                              "_");
    }
  }
  // get prim and abstract info
  (void)graph_info.append(op_exec_info->prim_id + "_");
  // get attr info
  const auto &op_prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &attr_map = op_prim->evaluate_added_attrs();
  (void)std::for_each(attr_map.begin(), attr_map.end(),
                      [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
  // Add output information (shape, type id) of the operator to graph_info, to avoid cache misses caused by
  // operators like DropoutGenMask whose output depends on the input values even when the input shapes are
  // the same.
  auto abstr = op_exec_info->abstract;
  MS_EXCEPTION_IF_NULL(abstr);
  auto build_shape = abstr->BuildShape();
  MS_EXCEPTION_IF_NULL(build_shape);
  (void)graph_info.append(build_shape->ToString() + "_");
  auto build_type = abstr->BuildType();
  MS_EXCEPTION_IF_NULL(build_type);
  (void)graph_info.append(std::to_string(build_type->type_id()) + "_");
  return graph_info;
}
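
// If the input at input_index is registered as a const-to-attr input for this primitive, converts the
// Python object into a Value and attaches it to the primitive as an attribute named after the input;
// returns true when the input has been consumed this way and should be skipped as a tensor input.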
bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                  const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(op_prim);
  auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  if (input_names_value == nullptr) {
    return false;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  if (input_index >= input_names_vec.size()) {
    MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
  }
  if (input_attrs.find(input_index) != input_attrs.end()) {
    ValuePtr value = parse::data_converter::PyDataToValue(input_object);
    MS_EXCEPTION_IF_NULL(value);
    auto input_name = input_names_vec[input_index];
    op_prim->AddAttr(input_name, value);
    return true;
  }
  return false;
}

void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
                              std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    input_tensors->emplace_back(tensor);
  }
  op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(tuple_inputs.size())}));
}

void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }
  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                                  std::vector<tensor::TensorPtr> *input_tensors, int64_t *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }
  auto tuple_inputs = py::cast<py::tuple>(input_object);
  if (tuple_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }
  auto inputs = py::cast<py::tuple>(input_object);
  if (py::isinstance<tensor::Tensor>(inputs[0])) {
    PlantTensorTupleToVector(inputs, op_prim, input_tensors);
  } else {
    ConvertValueTupleToTensor(input_object, input_tensors);
    *tensor_mask = kValueNodeTensorMask;
  }
}
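
// Converts a single Python input into one or more tensors: tensors pass through, float/int scalars become
// 0-d float32/int64 tensors marked as value nodes, numpy arrays are wrapped, and lists/tuples are handled
// recursively via ConvertMultiPyObjectToTensor; None is skipped and any other type is rejected.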
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                             std::vector<tensor::TensorPtr> *input_tensors, int64_t *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (py::isinstance<tensor::Tensor>(input_object)) {
    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  } else if (py::isinstance<py::float_>(input_object)) {
    double input_value = py::cast<py::float_>(input_object);
    tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::int_>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::array>(input_object)) {
    tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
  } else if (py::isinstance<py::list>(input_object)) {
    auto list_inputs = py::cast<py::list>(input_object);
    py::tuple tuple_inputs(list_inputs.size());
    for (size_t i = 0; i < tuple_inputs.size(); ++i) {
      tuple_inputs[i] = list_inputs[i];
    }
    ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    return;
  } else {
    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}
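
// Flattens all op inputs into input_tensors, folding registered const inputs into primitive attrs and
// recording one mask entry per produced tensor (0: data, 1: weight/parameter, 2: value node). Some ops
// opt out of the const-to-attr fold depending on the device target or dynamic-shape status.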
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  if (op_run_info->is_dynamic_shape &&
      dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
    MS_LOG(INFO) << "Current node is dynamic shape: " << op_run_info->op_name;
    reg_exist = false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
    if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
      reg_exist = false;
    }
  }
  if (op_run_info->op_name == prim::kPrimGatherD->name()) {
    // Gather op needs converting const input to attr on GPU device
    if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
      reg_exist = false;
    }
  }
  op_prim->BeginRecordAddAttr();
  size_t input_num = op_run_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    // convert const input to attr
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // convert const and tuple input to tensor
    int64_t tensor_mask = static_cast<int64_t>(op_run_info->inputs_mask[index]);
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // mark tensors, data : 0, weight : 1, valuenode: 2
    std::vector<int64_t> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
  op_prim->EndRecordAddAttr();
}
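
// Recursively converts a backend VectorRef result into nested Python tuples of tensors, so the output of
// a single-op run can be handed back to the Python layer.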
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      auto output = TransformBaseRefListToTuple(ref_list[i]);
      if (utils::isa<tensor::TensorPtr>(output)) {
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    return std::make_shared<PyObjectRef>(output_tensors);
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
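
// GetTupleSize / ConvertTupleArg / ConvertArgs flatten arbitrarily nested argument tuples into one flat
// tuple, e.g. (a, (b, c)) -> (a, b, c): first count the leaves, then fill the result in order.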
size_t GetTupleSize(const py::tuple &args) {
  size_t count = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      count += GetTupleSize(args[i]);
    } else {
      count += 1;
    }
  }
  return count;
}

void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) {
  for (size_t i = 0; i < arg.size(); i++) {
    if (py::isinstance<py::tuple>(arg[i])) {
      ConvertTupleArg(res, index, arg[i]);
    } else {
      (*res)[(*index)++] = arg[i];
    }
  }
}

py::tuple ConvertArgs(const py::tuple &args) {
  size_t tuple_size = GetTupleSize(args);
  py::tuple res(tuple_size);
  size_t index = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      ConvertTupleArg(&res, &index, args[i]);
    } else {
      res[index++] = args[i];
    }
  }
  return res;
}

void ClearPyNativeSession() { session = nullptr; }

PynativeExecutor::~PynativeExecutor() {
  MS_LOG(DEBUG) << "PynativeExecutor destructor";
  ClearRes();
}

void CheckPyNativeContext() {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto parallel_mode = parallel_context->parallel_mode();
  if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL &&
      ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    MS_LOG(EXCEPTION) << "PyNative only supports STAND_ALONE and DATA_PARALLEL, but got:" << parallel_mode;
  }
}
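
// Pybind entry point for executing a single op in PyNative mode. args is the (op name, primitive, inputs)
// triple packed on the Python side; failures clear executor state and surface as Python exceptions, using
// the same translation scheme as PynativeExecutorTry above.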
py::object RunOp(const py::args &args) {
  CheckPyNativeContext();
  auto executor = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(DEBUG) << "RunOp name: " << op_exec_info->op_name << " start, args: " << args.size();
  try {
    return executor->RunOpInner(op_exec_info);
  } catch (const py::error_already_set &ex) {
    executor->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    executor->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    executor->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    executor->ClearRes();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    executor->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    executor->ClearRes();
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compiling graph. Exception name: " << exName;
  }
}
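
// Core single-op pipeline: build a CNode for the grad graph (when the grad flag is set), infer the output
// abstract (served from the per-prim cache where possible), short-circuit constant outputs, execute on the
// selected backend, then record the output node and abstract for later grad-graph construction.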
py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
    py::tuple ret = RunOpWithInitBackendPolicy(op_exec_info);
    if (ret.size() == 1) {
      return ret[0];
    }
    return std::move(ret);
  }
  // make cnode for building grad graph if grad flag is set.
  abstract::AbstractBasePtrList args_spec_list;
  std::vector<bool> op_masks;
  auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list);
  op_exec_info->inputs_mask = op_masks;
  // get output abstract info
  bool is_find = false;
  GetOpOutputAbstract(op_exec_info, args_spec_list, &is_find);
  MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString();
  // infer output value for const prim
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
  if (!output["value"].is_none()) {
    return output["value"];
  }
  if (prim->is_const_prim()) {
    return py::cast("");
  }
  // add output abstract info into cache
  if (!is_find && !op_exec_info->is_dynamic_shape) {
    // const_value need infer every step
    auto &out = prim_abs_list_[prim->id()];
    out[args_spec_list].abs = op_exec_info->abstract;
    out[args_spec_list].attrs = prim->evaluate_added_attrs();
    MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
  }
  // run op with selected backend
  auto result = RunOpWithInitBackendPolicy(op_exec_info);
  py::object out_real;
  if (result.size() == 1 && op_exec_info->abstract != nullptr &&
      !op_exec_info->abstract->isa<abstract::AbstractSequeue>()) {
    out_real = result[0];
  } else {
    out_real = result;
  }
  // update output abstract for cnode
  if (cnode != nullptr) {
    cnode->set_abstract(op_exec_info->abstract);
  }
  std::string obj_id = GetId(out_real);
  node_abs_map_[obj_id] = op_exec_info->abstract;
  // save info for building grad graph
  SaveOutputNodeMap(obj_id, out_real, cnode);
  SaveAllResult(op_exec_info, cnode, out_real);
  // Update the abstract and device address of value node with tensor in grad graph
  UpdateAbstractAndDeviceAddress(op_exec_info, out_real);
  return out_real;
}
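
// Unpacks the (op name, primitive, inputs) args triple into an OpExecInfo. When the grad flag is set it
// also assigns a per-op index (op name plus running counter) and appends it to the op info of the cell
// currently on the stack, so forward ops can be matched against the grad graph later.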
OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(ERROR) << "Three args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
  auto op_name = py::cast<std::string>(args[PY_NAME]);
  op_exec_info->op_name = op_name;
  if (grad_flag()) {
    op_exec_info->op_index = op_name + std::to_string(op_index_map_[op_name]);
    if (!cell_op_info_stack_.empty()) {
      std::string &cell_op_info = cell_op_info_stack_.top();
      cell_op_info += op_exec_info->op_index;
    }
    op_index_map_[op_name]++;
  }
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  MS_EXCEPTION_IF_NULL(prim);
  if (!prim->HasPyObj()) {
    MS_LOG(EXCEPTION) << "Pyobj is empty";
  }
  op_exec_info->prim_id = GetId(prim->GetPyObj());
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  op_exec_info->op_inputs = args[PY_INPUTS];
  return op_exec_info;
}
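
// Builds the CNode that represents this op in the grad graph: applies mixed-precision casts to the inputs,
// collects each input's abstract (broadening tensor abstracts unless the prim needs const values), records
// which inputs are Parameters in op_masks, and, when graph construction is active, creates the node in the
// current FuncGraph.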
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
                                       abstract::AbstractBasePtrList *args_spec_list) {
  MS_EXCEPTION_IF_NULL(op_masks);
  MS_EXCEPTION_IF_NULL(args_spec_list);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto prim = op_exec_info->py_primitive;
  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(prim));
  const auto &signature = prim->signatures();
  auto sig_size = signature.size();
  auto size = op_exec_info->op_inputs.size();
  // ignore signature for cast op
  if (sig_size > 0 && sig_size != size) {
    MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the required "
                             << "inputs size " << sig_size;
  }
  if (op_exec_info->op_name != prim::kPrimCast->name()) {
    RunParameterAutoMixPrecisionCast(op_exec_info);
  }
  MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    abstract::AbstractBasePtr abs = nullptr;
    const auto &obj = op_exec_info->op_inputs[i];
    auto id = GetId(obj);
    auto it = node_abs_map_.find(id);
    if (it != node_abs_map_.end()) {
      abs = it->second;
    }
    bool op_mask = false;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor) {
        op_mask = meta_tensor->is_parameter();
      }
    }
    MS_LOG(DEBUG) << "Gen args i " << i << " op_mask " << op_mask;
    (*op_masks).emplace_back(op_mask);
    if (need_construct_graph()) {
      AnfNodePtr input_node = nullptr;
      if (!graph_info_map_.empty() && !top_cell_list_.empty()) {
        input_node = GetInput(obj, op_mask);
      }
      // update abstract
      if (input_node != nullptr && input_node->abstract() != nullptr) {
        abs = input_node->abstract();
      }
      if (input_node != nullptr) {
        inputs.emplace_back(input_node);
      }
    }
    auto const_input_index = prim->get_const_input_indexes();
    bool have_const_input = !const_input_index.empty();
    bool is_const_prim = prim->is_const_prim();
    MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                  << prim->is_const_prim();
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    if (abs == nullptr || is_const_prim || is_const_input) {
      MS_LOG(DEBUG) << "MakeCNode: node " << id << " not found in map";
      ValuePtr input_value = PyAttrValue(obj);
      abs = input_value->ToAbstract();
      if (!is_const_prim && !is_const_input) {
        auto config = abstract::AbstractBase::kBroadenTensorOnly;
        abs = abs->Broaden(config);
        MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
      }
      node_abs_map_[id] = abs;
    }
    (*args_spec_list).emplace_back(abs);
  }
  CNodePtr cnode = nullptr;
  if (need_construct_graph()) {
    MS_EXCEPTION_IF_NULL(curr_g_);
    cnode = curr_g_->NewCNode(inputs);
    MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << " new cnode is " << cnode->DebugString(4);
  }
  return cnode;
}
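
// Looks up the output abstract in the per-primitive cache keyed by the input abstracts; on a miss it runs
// Python-side inference. If the inferred shape contains "-1" the op may be dynamically shaped, so C++-side
// shape inference is run (with a "Dynamic"-prefixed primitive where one is registered) to confirm.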
void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
                                           const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
  MS_EXCEPTION_IF_NULL(is_find);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  *is_find = false;
  auto op_name = op_exec_info->op_name;
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
    auto abs_list = prim_abs_list_[prim->id()];
    MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
    if (abs_list.find(args_spec_list) != abs_list.end()) {
      MS_LOG(DEBUG) << "Match prim ok " << op_name;
      op_exec_info->abstract = abs_list[args_spec_list].abs;
      prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
      *is_find = true;
    }
  }
  if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
    // use python infer method
    if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
      PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
    }
  }
  // get output dynamic shape info
  auto py_abstract = op_exec_info->abstract;
  MS_EXCEPTION_IF_NULL(py_abstract);
  auto py_shape = py_abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(py_shape);
  auto py_shape_info = py_shape->ToString();
  if (py_shape_info.find("-1") != string::npos) {
    if (DynamicShapeConstInputToAttr.find(op_name) != DynamicShapeConstInputToAttr.end()) {
      auto new_prim_name = "Dynamic" + op_name;
      auto attrs = prim->attrs();
      prim = std::make_shared<PrimitivePy>(new_prim_name, py::object());
      prim->SetAttrs(attrs);
    }
    auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
    MS_EXCEPTION_IF_NULL(c_abstract);
    auto c_shape = c_abstract->BuildShape();
    MS_EXCEPTION_IF_NULL(c_shape);
    auto c_shape_info = c_shape->ToString();
    MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info;
    if (c_shape_info.find("-1") != string::npos) {
      op_exec_info->is_dynamic_shape = true;
    }
  }
}
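
// Emits an implicit dtype cast by synthesizing a Cast op call (the cast function fetched from
// mindspore.ops.functional, the target dtype from mindspore.common.dtype) and pushing it back through
// RunOpInner, so the cast is recorded in the grad graph like any user-issued op. The three-slot args tuple
// mirrors the (PY_PRIM, PY_NAME, PY_INPUTS) layout that GenerateOpExecInfo expects.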
py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
                                        size_t index) {
  py::tuple cast_args(3);
  cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  cast_args[PY_NAME] = prim::kPrimCast->name();
  std::string dst_type_str = TypeIdToMsTypeStr(type_id);
  py::object dst_type = parse::python_adapter::GetPyFn(kMSDtypeModelName, dst_type_str);
  py::tuple inputs(2);
  inputs[0] = arg;
  inputs[1] = dst_type;
  cast_args[PY_INPUTS] = inputs;
  auto op_exec = GenerateOpExecInfo(cast_args);
  op_exec->is_mixed_precision_cast = true;
  op_exec->next_op_name = op_name;
  op_exec->next_input_index = index;
  return RunOpInner(op_exec);
}

py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name,
                                                     size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  auto tensor = py::cast<tensor::TensorPtr>(obj);
  auto cast_type = tensor->cast_dtype();
  py::object cast_output = obj;
  if (cast_type != nullptr) {
    auto source_element = tensor->Dtype();
    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
      MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
      *is_cast = true;
      return DoAutoCast(obj, cast_type->type_id(), op_name, index);
    }
  }
  return cast_output;
}

py::object PynativeExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple,
                                                          const std::string &op_name, size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  auto tuple_size = static_cast<int64_t>(tuple.size());
  py::tuple result(tuple_size);
  for (int64_t i = 0; i < tuple_size; i++) {
    if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
      MS_LOG(DEBUG) << "Call cast for item " << i;
      result[i] = DoParamMixPrecisionCast(is_cast, tuple[i], op_name, index);
    } else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
      result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i], op_name, index);
    } else {
      result[i] = tuple[i];
    }
  }
  return std::move(result);
}
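
// Applies the implicit type conversion computed by GetDstType to each input according to the primitive's
// signature. Writable (kRWWrite) inputs must already be Parameters of the target dtype; for them a
// mismatch raises instead of casting. Inputs that are neither tensors nor int/float scalars are rejected.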
void PynativeExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
                                       const std::vector<SignatureEnumDType> &dtypes,
                                       const OpExecInfoPtr &op_exec_info) {
  const auto &signature = prim->signatures();
  auto &out_args = op_exec_info->op_inputs;
  for (size_t i = 0; i < out_args.size(); ++i) {
    // No need to implicit cast if no dtype.
    if (dtypes.empty() || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
      continue;
    }
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end() || it->second == kTypeUnknown) {
      continue;
    }
    MS_LOG(DEBUG) << "Check inputs " << i;
    auto obj = out_args[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      sig = signature[i].rw;
    }
    bool is_parameter = false;
    TypeId arg_type_id = kTypeUnknown;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
      if (arg->is_parameter()) {
        is_parameter = true;
        MS_LOG(DEBUG) << "Parameter is read " << i;
      }
      arg_type_id = arg->data_type();
    }
    // implicit cast
    bool is_same_type = false;
    if (arg_type_id != kTypeUnknown) {
      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
    }
    if (sig == SignatureEnumRW::kRWWrite) {
      if (!is_parameter) {
        prim::RaiseExceptionForCheckParameter(prim->name(), i, "not");
      }
      if (arg_type_id != kTypeUnknown) {
        if (!is_same_type) {
          prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
                                                 TypeIdToMsTypeStr(it->second));
        }
      }
    }
    if (is_same_type) {
      continue;
    }
    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
                              << "th input is of a type that does not support implicit conversion: "
                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
                              << py::cast<py::str>(obj) << ".";
    }
    py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i);
    out_args[i] = cast_output;
  }
}
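
// Two-phase cast pass over the op inputs: first apply per-parameter mixed-precision casts (driven by each
// tensor's cast_dtype, recursing into tuples/lists), then compute the implicit-conversion targets from the
// signature dtypes and apply them via DoSignatrueCast.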
void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
  size_t size = op_exec_info->op_inputs.size();
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &signature = prim->signatures();
  for (size_t i = 0; i < size; i++) {
    auto obj = op_exec_info->op_inputs[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      sig = signature[i].rw;
    }
    MS_LOG(DEBUG) << "Check mix precision " << op_exec_info->op_name << " input " << i << " "
                  << std::string(py::repr(obj));
    // mix precision for non param
    bool is_cast = false;
    py::object cast_output;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor && meta_tensor->is_parameter()) {
        if (sig != SignatureEnumRW::kRWRead) {
          continue;
        }
      }
      // redundant cast call if the tensor is a const Tensor.
      cast_output = DoParamMixPrecisionCast(&is_cast, obj, prim->name(), i);
    } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
      // mix precision for tuple inputs
      cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj, prim->name(), i);
    }
    if (is_cast) {
      op_exec_info->op_inputs[i] = cast_output;
    }
  }
  std::vector<SignatureEnumDType> dtypes;
  bool has_dtype_sig = GetSignatureType(prim, &dtypes);
  std::map<SignatureEnumDType, TypeId> dst_types;
  if (has_dtype_sig) {
    // fetch info for implicit cast
    auto type_indexes = GetTypeIndex(dtypes);
    dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
  }
  MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name;
  DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
}
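
// Maps a Python input object onto an ANF node in the current graph. Parameters (op_mask set) become free
// parameters of the df_builder graph on first use; objects produced by earlier ops reuse their recorded
// node; all-tensor tuples become MakeTuple CNodes; everything else becomes a value node.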
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
  AnfNodePtr node = nullptr;
  std::string obj_id = GetId(obj);
  if (op_mask) {
    MS_LOG(DEBUG) << "Cell parameters(weights)";
    // get the parameter name from parameter object
    auto name_attr = parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    auto param_name = py::cast<std::string>(name_attr);
    auto df_builder = GetDfbuilder(top_cell_id_);
    MS_EXCEPTION_IF_NULL(df_builder);
    auto graph_info = graph_info_map_.at(df_builder);
    MS_EXCEPTION_IF_NULL(graph_info);
    if (graph_info->params.find(obj_id) == graph_info->params.end()) {
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      free_param->debug_info()->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(obj);
      free_param->set_default_param(value);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      SetParamNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetParamNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
      SetNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
      return free_param;
    }
    node = graph_info->node_map.at(obj_id).first;
    MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
    return node;
  }
  auto graph_info = graph_info_map_.at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  if (graph_info->node_map.find(obj_id) != graph_info->node_map.end()) {
    // op(x, y)
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj, obj_id);
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    auto tuple = obj.cast<py::tuple>();
    // cell((1, 2)): tuples that mix scalars and tensors are not supported
    if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
      return MakeValueNode(obj, obj_id);
    }
    std::vector<AnfNodePtr> args;
    args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple_size = tuple.size();
    for (size_t i = 0; i < tuple_size; i++) {
      args.emplace_back(GetInput(tuple[i], false));
    }
    auto cnode = curr_g_->NewCNode(args);
    SetNodeMapInGraphInfoMap(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    node = MakeValueNode(obj, obj_id);
  }
  node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
                  : MS_LOG(DEBUG) << "Get input node " << node->ToString() << " " << obj_id;
  return node;
}
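
// After re-running a step, rebinds the forward tensors cached in the grad graph's value nodes to the
// freshly produced outputs: on the first step it only records output tensor ids per op index; on later
// steps it copies shape/dtype and either shares the new device address (non-CPU targets) or memcpys the
// data in place on CPU.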
void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (!grad_flag()) {
    return;
  }
  auto op_index = op_exec_info->op_index;
  auto output_value = PyAttrValue(out_real);
  MS_EXCEPTION_IF_NULL(output_value);
  std::vector<tensor::TensorPtr> output_tensors;
  TensorValueToTensor(output_value, &output_tensors);
  if (cell_op_index_with_tensor_id_[top_cell_id_].find(op_index) == cell_op_index_with_tensor_id_[top_cell_id_].end()) {
    // First step: record the tensor ids of this op's outputs
    std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) {
      cell_op_index_with_tensor_id_[top_cell_id_][op_index].emplace_back(tensor->id());
    });
    return;
  }
  auto ms_context = MsContext::GetInstance();
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const auto &tensor_id_list = cell_op_index_with_tensor_id_[top_cell_id_][op_index];
  for (size_t i = 0; i < tensor_id_list.size(); ++i) {
    auto tensor_id = tensor_id_list[i];
    if (cell_tensor_id_with_tensor_[top_cell_id_].find(tensor_id) != cell_tensor_id_with_tensor_[top_cell_id_].end()) {
      auto &new_tensor = output_tensors[i];
      auto &tensors_in_value_node = cell_tensor_id_with_tensor_[top_cell_id_][tensor_id];
      std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
        MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id "
                      << tensor->id() << ", device address " << tensor->device_address().get()
                      << " with new tensor obj " << new_tensor.get() << ", tensor id " << new_tensor->id()
                      << ", device address " << new_tensor->device_address().get();
        tensor->set_shape(new_tensor->shape());
        tensor->set_data_type(new_tensor->data_type());
        if (target != kCPUDevice) {
          tensor->set_device_address(new_tensor->device_address());
        } else {
          auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
          auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
          auto old_ptr = old_device_address->GetMutablePtr();
          auto new_ptr = new_device_address->GetPtr();
          MS_EXCEPTION_IF_NULL(old_ptr);
          MS_EXCEPTION_IF_NULL(new_ptr);
          auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
          if (ret != EOK) {
            MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
          }
        }
      });
    }
  }
}
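// Note: scans the value nodes of the optimized func graph and records, per tensor id, every tensor produced by a
// forward op of the current top cell, so UpdateAbstractAndDeviceAddress can refresh them on subsequent steps.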
void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  std::set<std::string> forward_op_tensor_id;
  for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) {
    const auto &tensor_id_list = elem.second;
    for (const auto &tensor_id : tensor_id_list) {
      forward_op_tensor_id.emplace(tensor_id);
    }
  }
  cell_tensor_id_with_tensor_[top_cell_id_].clear();
  const auto &func_graph = resource->func_graph();
  const auto &value_node_list = func_graph->value_nodes();
  for (const auto &elem : value_node_list) {
    auto value_node = elem.first->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    std::vector<tensor::TensorPtr> tensors;
    TensorValueToTensor(value_node->value(), &tensors);
    for (const auto &tensor : tensors) {
      if (tensor->device_address() != nullptr &&
          forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) {
        cell_tensor_id_with_tensor_[top_cell_id_][tensor->id()].emplace_back(tensor);
        MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id()
                      << ", device address " << tensor->device_address().get();
      }
    }
  }
}
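// Note: releases device memory held by value-node tensors of the previous step; skipped entirely on CPU.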
void PynativeExecutor::CleanPreMemoryInValueNode() {
  auto ms_context = MsContext::GetInstance();
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target == "CPU") {
    return;
  }
  if (has_dynamic_cell_) {
    std::set<std::string> forward_op_tensor_id;
    for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) {
      const auto &tensor_id_list = elem.second;
      for (const auto &tensor_id : tensor_id_list) {
        forward_op_tensor_id.emplace(tensor_id);
      }
    }
    for (auto &tensor : all_value_node_tensors_) {
      if (tensor->device_address() != nullptr &&
          forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) {
        tensor->device_address()->ClearDeviceMemory();
        tensor->set_device_address(nullptr);
      }
    }
    all_value_node_tensors_.clear();
  }
  const auto &tensor_id_with_tensor = cell_tensor_id_with_tensor_[top_cell_id_];
  for (const auto &elem : tensor_id_with_tensor) {
    const auto &tensors_in_value_node = elem.second;
    for (const auto &tensor : tensors_in_value_node) {
      MS_EXCEPTION_IF_NULL(tensor);
      tensor->set_device_address(nullptr);
    }
  }
}
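// Note: resolves a Python object to its AnfNode in the current graph. An index list of {-1} denotes the node
// itself; otherwise a chain of TupleGetItem nodes is built to extract the element from a tuple output.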
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
  auto graph_info = graph_info_map_.at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  auto &out = graph_info->node_map.at(obj_id);
  if (out.second.size() == 1 && out.second[0] == -1) {
    return out.first;
  }
  MS_LOG(DEBUG) << "Output size " << out.second.size();
  // Params node
  if (graph_info->params.find(obj_id) != graph_info->params.end()) {
    auto para_node = out.first;
    for (auto &idx : out.second) {
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node,
                                                    NewValueNode(idx)};
      para_node = curr_g_->NewCNode(tuple_get_item_inputs);
    }
    return para_node;
  }
  // Normal node
  auto node = out.first->cast<CNodePtr>();
  auto abs = node->abstract();
  ValuePtr out_obj = nullptr;
  if (node->forward().first != nullptr) {
    out_obj = node->forward().first;
  } else {
    out_obj = PyAttrValue(obj);
  }
  for (auto &idx : out.second) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
    node = curr_g_->NewCNode(tuple_get_item_inputs);
    if (out_obj->isa<ValueTuple>()) {
      node->add_input_value(out_obj, "");
      node->add_input_value(MakeValue(idx), "");
      out_obj = (*out_obj->cast<ValueTuplePtr>())[idx];
      node->set_forward(out_obj, "");
    }
    if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
      auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
      MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
      node->set_abstract(prim_abs);
    }
  }
  if (node->abstract() != nullptr) {
    node_abs_map_[obj_id] = node->abstract();
  }
  MS_LOG(DEBUG) << "GetObjNode output " << node->DebugString(6);
  return node;
}
AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
  ValuePtr converted_ret = nullptr;
  parse::ConvertData(obj, &converted_ret);
  auto node = NewValueNode(converted_ret);
  SetNodeMapInGraphInfoMap(curr_g_, obj_id, node);
  return node;
}
void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real,
                                         const AnfNodePtr &cnode) {
  if (!need_construct_graph()) {
    MS_LOG(DEBUG) << "No need to save output";
    return;
  }
  MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString(4) << " id " << obj_id;
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    auto size = static_cast<int64_t>(value.size());
    if (size > 1) {
      for (int64_t i = 0; i < size; ++i) {
        auto value_id = GetId(value[i]);
        SetNodeMapInGraphInfoMap(curr_g_, value_id, cnode, i);
      }
    }
  }
  SetNodeMapInGraphInfoMap(curr_g_, obj_id, cnode);
  SetPyObjInGraphInfoMap(curr_g_, obj_id);
}
void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node,
                                     const py::object &out_real) {
  if (!grad_flag() || node == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Save input objects
  size_t size = op_exec_info->op_inputs.size();
  for (size_t i = 0; i < size; i++) {
    auto obj = op_exec_info->op_inputs[i];
    auto obj_id = GetId(obj);
    if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) {
      cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]);
    } else {
      cnode->add_input_value(nullptr, "");
    }
  }
  // Save the output object
  auto output_value = PyAttrValue(out_real);
  MS_EXCEPTION_IF_NULL(output_value);
  cnode->set_forward(output_value, op_exec_info->op_index);
  auto out_id = GetId(out_real);
  if (py::isinstance<py::tuple>(out_real)) {
    auto tuple_item = py::cast<py::tuple>(out_real);
    for (size_t i = 0; i < tuple_item.size(); i++) {
      auto tuple_item_id = GetId(tuple_item[i]);
      obj_to_forward_id_[tuple_item_id] = op_exec_info->op_index;
    }
  }
  obj_to_forward_id_[out_id] = op_exec_info->op_index;
}
void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) {
  if (t_map == nullptr) {
    return;
  }
  for (size_t i = 0; i < tuple->size(); i++) {
    ValuePtr tuple_i = (*tuple)[i];
    if (tuple_i->isa<tensor::Tensor>()) {
      auto t = tuple_i->cast<tensor::TensorPtr>();
      (*t_map)[t->id()] = t;
    } else if (tuple_i->isa<ValueTuple>()) {
      GenTupleMap(tuple_i->cast<ValueTuplePtr>(), t_map);
    }
  }
  MS_LOG(DEBUG) << "End GenTupleMap " << tuple->ToString();
}
ValuePtr PynativeExecutor::CleanTupleAddr(const ValueTuplePtr &tuple) {
  std::vector<ValuePtr> value_list;
  for (size_t i = 0; i < tuple->size(); i++) {
    ValuePtr tuple_i = (*tuple)[i];
    if (tuple_i->isa<tensor::Tensor>()) {
      auto t = tuple_i->cast<tensor::TensorPtr>();
      auto new_tensor = std::make_shared<tensor::Tensor>(*t);
      new_tensor->set_device_address(nullptr);
      value_list.emplace_back(new_tensor);
    } else if (tuple_i->isa<ValueTuple>()) {
      value_list.emplace_back(CleanTupleAddr(tuple_i->cast<ValueTuplePtr>()));
    } else {
      MS_LOG(DEBUG) << "Tuple[i] value " << tuple_i->ToString();
      value_list.emplace_back(tuple_i);
    }
  }
  MS_LOG(DEBUG) << "End CleanTupleAddr";
  return std::make_shared<ValueTuple>(value_list);
}
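// Note: entry point for executing a single op. It picks a backend policy from the context, dispatches the op, and
// raises if the chosen backend could not run it.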
py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) {
  auto backend_policy = InitEnv(op_exec_info);
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // Returns a null py::tuple on error
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(EXCEPTION) << "Failed to run " << op_exec_info->op_name;
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}
MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!context::IsTsdOpened(ms_context)) {
    if (!context::OpenTsd(ms_context)) {
      MS_LOG(EXCEPTION) << "Open tsd failed";
    }
  }
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  context::PynativeInitGe(ms_context);
  backend_policy = kMsBackendGeOnly;
#endif
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  return backend_policy;
}
py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                                    PynativeStatusCode *const status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // Use vm only
      MS_LOG(INFO) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // Use GE first, fall back to vm when GE fails
      MS_LOG(INFO) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // Use Ms first, use others when Ms fails
      MS_LOG(INFO) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}
py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  auto &op_inputs = op_exec_info->op_inputs;
  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf" ||
      op_exec_info->op_name == "stop_gradient") {
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      auto input_obj_id = GetId(input);
      auto tensor = py::cast<tensor::TensorPtr>(input);
      MS_EXCEPTION_IF_NULL(tensor);
      if (obj_to_forward_id_.find(input_obj_id) == obj_to_forward_id_.end() &&
          op_exec_info->op_name == "HookBackward") {
        // The input object is not an output of a forward cnode, e.g. a parameter
        result[i] = tensor;
      } else {
        // The input object is an output of a forward cnode
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
        new_tensor->set_device_address(tensor->device_address());
        new_tensor->set_sync_status(tensor->sync_status());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto primitive = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(primitive);
  auto result = primitive->RunPyComputeFunction(op_inputs);
  MS_LOG(INFO) << "RunOpInVM end";
  if (py::isinstance<py::none>(result)) {
    MS_LOG(ERROR) << "VM got result none; please check whether getting the compute function failed";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  *status = PYNATIVE_SUCCESS;
  if (py::isinstance<py::tuple>(result)) {
    return result;
  }
  py::tuple tuple_result = py::make_tuple(result);
  return std::move(tuple_result);
}
py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(status);
  MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
  if (session == nullptr) {
    std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    session = session::SessionFactory::Get().Create(device_target);
    MS_EXCEPTION_IF_NULL(session);
    session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  }
  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int64_t> tensors_mask;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  // Get graph info to check whether the graph already exists in the cache
  std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
#if defined(__APPLE__)
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive,
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->next_op_name,
                                    static_cast<int>(op_exec_info->next_input_index)};
#else
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive,
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->next_op_name,
                                    op_exec_info->next_input_index};
#endif
  VectorRef outputs;
  session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
  if (op_exec_info->is_dynamic_shape) {
    op_exec_info->abstract = op_run_info.abstract;
  }
  auto result = BaseRefToPyData(outputs);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "End run op [" << op_exec_info->op_name << "] with backend policy ms";
  return result;
}
void PynativeExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); }
void PynativeExecutor::PushCurrentCellOpInfoToStack() {
  std::string cell_op_info = "Cell ops: ";
  cell_op_info_stack_.push(cell_op_info);
}
void PynativeExecutor::PopGraphStack() {
  if (graph_stack_.empty()) {
    MS_LOG(EXCEPTION) << "Stack graph_stack_ is empty";
  }
  graph_stack_.pop();
  if (!graph_stack_.empty()) {
    curr_g_ = graph_stack_.top();
  }
}
void PynativeExecutor::PopCurrentCellOpInfoFromStack() {
  if (cell_op_info_stack_.empty()) {
    MS_LOG(EXCEPTION) << "The cell op info stack is empty";
  }
  cell_op_info_stack_.pop();
}
std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
  auto cell_id = GetId(cell);
  for (size_t i = 0; i < args.size(); i++) {
    std::string arg_id = GetId(args[i]);
    auto it = node_abs_map_.find(arg_id);
    if (it != node_abs_map_.end()) {
      cell_id += "_" + it->second->BuildShape()->ToString();
      cell_id += it->second->BuildType()->ToString();
    } else {
      auto abs = PyAttrValue(args[i])->ToAbstract();
      auto config = abstract::AbstractBase::kBroadenTensorOnly;
      abs = abs->Broaden(config);
      node_abs_map_[arg_id] = abs;
      cell_id += "_" + abs->BuildShape()->ToString();
      cell_id += abs->BuildType()->ToString();
    }
  }
  return GetTensorCellId(cell_id);
}
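// Note: a cell id encodes the shapes and types of the call arguments. When an argument has no shape ("NoShape",
// e.g. a scalar), this tries to reuse the id of an already-recorded Tensor-typed variant of the same cell.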
std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) {
  if (cell_id.find("NoShape") == std::string::npos) {
    return cell_id;
  }
  std::string key = cell_id.substr(0, PTR_LEN);
  auto fn = [](const std::string &str, std::vector<std::string> &value) {
    size_t pos = 0;
    size_t pre_pos = 0;
    while ((pos = str.find_first_of('_', pre_pos)) != std::string::npos) {
      value.emplace_back(str.substr(pre_pos, pos - pre_pos + 1));
      pre_pos = pos + 1;
    }
    value.emplace_back(str.substr(pre_pos));
  };
  auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&key](const CellInfoPtr &value) {
    return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos;
  });
  if (it != cell_graph_list_.end()) {
    std::vector<std::string> pre_cell_id;
    std::vector<std::string> cur_cell_id;
    fn((*it)->cell_id, pre_cell_id);
    fn(cell_id, cur_cell_id);
    auto pre_tensor_size = pre_cell_id.size();
    if (pre_tensor_size == cur_cell_id.size()) {
      size_t same_tensor_count = 0;
      for (size_t i = 0; i < pre_tensor_size; ++i) {
        if (cur_cell_id[i].find("NoShape") != std::string::npos || cur_cell_id[i] == pre_cell_id[i]) {
          ++same_tensor_count;
        }
      }
      if (same_tensor_count == pre_tensor_size) {
        MS_LOG(DEBUG) << "Changed cell id from " << cell_id << " to " << (*it)->cell_id;
        return (*it)->cell_id;
      }
    }
  }
  return cell_id;
}
void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR(filename, graph);
  }
}
bool PynativeExecutor::IsNestedGrad() const {
  MS_LOG(DEBUG) << "Grad nested order is " << grad_order_;
  return grad_order_ > 1;
}
bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
  return std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
                     [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
}
bool PynativeExecutor::IsTopestGraph(const std::string &cell_id) {
  return std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
                     [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; });
}
void PynativeExecutor::UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled) {
  auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
  if (it != top_cell_list_.end()) {
    (*it)->do_vm_compiled = vm_compiled;
  }
}
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) {
  return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) {
    return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos;
  });
}
bool PynativeExecutor::IsFirstGradStep(const std::string &cell_id) { return !CheckCellGraph(cell_id, true); }
void PynativeExecutor::SubNestedGradOrder() {
  if (grad_order_ > 0) {
    --grad_order_;
  }
}
bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) {
  return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfoPtr &value) {
    return value->cell_id == cell_id && (!is_grad || value->is_grad);
  });
}
bool PynativeExecutor::CheckDynamicCell(const std::string &cell_id) {
  return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(),
                     [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_dynamic; });
}
bool PynativeExecutor::CheckRealDynamicCell(const std::string &cell_id) {
  return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) {
    return value->cell_id == cell_id && value->is_real_dynamic;
  });
}
void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
  // Abnormal case
  if (top_cell_list_.empty() && !graph_stack_.empty()) {
    graph_id_ = 0;
    graph_info_map_.clear();
    cell_graph_list_.clear();
    std::stack<FuncGraphPtr>().swap(graph_stack_);
  }
  if (CheckRealDynamicCell(cell_id)) {
    if (IsTopGraph(cell_id) && graph_stack_.empty() && !IsBpropGraph(cell_id)) {
      // Clear previous step resource
      auto resource = GetResource(cell_id);
      if (resource != nullptr && resource->results().find(pipeline::kBackend) != resource->results().end()) {
        compile::BackendPtr backend = resource->results()[pipeline::kBackend].cast<compile::BackendPtr>();
        auto ms_backend = std::dynamic_pointer_cast<compile::MsBackend>(backend);
        ms_backend->ClearSessionGraphs();
      }
    }
  }
}
FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
  // If a top graph holds this cell
  for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) {
    if (cell_id.find((*it)->cell_id) != std::string::npos) {
      return (*it)->df_builder;
    }
  }
  // Current cell is not a top graph; get the first top cell
  if (!top_cell_list_.empty()) {
    return top_cell_list_.front()->df_builder;
  }
  return nullptr;
}
ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) {
  for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) {
    if (cell_id.find((*it)->cell_id) != std::string::npos) {
      return (*it)->resource;
    }
  }
  // Current cell is not a top graph; get the first top cell
  if (!top_cell_list_.empty()) {
    return top_cell_list_.front()->resource;
  }
  return nullptr;
}
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                            parse::AstMainType type) {
  MS_EXCEPTION_IF_NULL(ast);
  if (py::isinstance<py::none>(node)) {
    MS_LOG(DEBUG) << "Get none type node!";
    return "";
  }
  auto node_type = ast->GetNodeType(node);
  MS_EXCEPTION_IF_NULL(node_type);
  // Check node type
  parse::AstMainType node_main_type = node_type->main_type();
  if (node_main_type != type) {
    MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
    return "";
  }
  std::string node_name = node_type->node_name();
  MS_LOG(DEBUG) << "Ast node is " << node_name;
  return node_name;
}
void PynativeExecutor::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
  MS_EXCEPTION_IF_NULL(ast);
  py::list args = ast->GetArgs(fn_node);
  for (size_t i = 1; i < args.size(); i++) {
    std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
    MS_LOG(DEBUG) << "Input arg name: " << arg_name;
    cell_input_args_.emplace(arg_name);
  }
}
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse if/while expr";
  py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
  const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
  if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
    py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
    py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
    if (comparators_node.empty()) {
      MS_LOG(DEBUG) << "Get comparators node failed!";
      return false;
    }
    auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
    auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
    // e.g. while self.a > self.b, where self.a or self.b is changed in the loop body
    if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
      auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
      std::string left_variable;
      if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
        left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
      }
      auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
      std::string right_variable;
      if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
        right_variable =
          py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
      }
      return ParseBodyContext(ast, node, {left_variable, right_variable});
    }
    // e.g. if a[0]
    if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
      py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
      left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
    }
    MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
    if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
        unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
      return true;
    }
  }
  // e.g. if flag:
  if (node_name == parse::NAMED_PRIMITIVE_NAME) {
    std::string id = py::cast<std::string>(test_node.attr("id"));
    if (cell_input_args_.find(id) != cell_input_args_.end()) {
      return true;
    }
  }
  return false;
}
bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse assign expr";
  py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
  const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
  if (node_name == parse::NAMED_PRIMITIVE_CALL) {
    py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
    const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
    if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
      py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
      py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
      const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
      if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end()) {
        return true;
      }
    }
  }
  return false;
}
bool PynativeExecutor::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                              const std::vector<std::string> &compare_prim) {
  MS_LOG(DEBUG) << "Parse augassign expr";
  bool ret = false;
  if (compare_prim.empty()) {
    return ret;
  }
  py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
  if (py::isinstance<py::none>(target_node)) {
    MS_LOG(DEBUG) << "Parse target node is none!";
    return ret;
  }
  py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
  if (py::isinstance<py::none>(value_node)) {
    MS_LOG(DEBUG) << "Parse value node is none!";
    return ret;
  }
  std::string assign_prim;
  if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
    assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
  }
  auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
  if (iter != compare_prim.end()) {
    ret = true;
  }
  return ret;
}
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse for expr";
  py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
  if (py::isinstance<py::none>(body_node)) {
    MS_LOG(DEBUG) << "Parse body of for expression is none!";
    return false;
  }
  py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
  size_t count = LongToSize(pcount);
  MS_LOG(DEBUG) << "The for nodes count in body is " << count;
  for (size_t i = 0; i < count; ++i) {
    auto it = py::cast<py::list>(body_node)[i];
    const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
    if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
      return true;
    }
  }
  return false;
}
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
                                        const std::vector<std::string> &compare_prim) {
  MS_EXCEPTION_IF_NULL(ast);
  py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
  if (py::isinstance<py::none>(func_obj)) {
    MS_LOG(DEBUG) << "Parse body of cell is none!";
    return false;
  }
  py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
  size_t count = IntToSize(pcount);
  MS_LOG(DEBUG) << "The nodes count in body is " << count;
  bool ret = false;
  for (size_t i = 0; i < count; ++i) {
    auto node = py::cast<py::list>(func_obj)[i];
    const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
    if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
      ret = ParseAssignExprNode(ast, node);
    } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
      ret = ParseAugAssignExprNode(ast, node, compare_prim);
    } else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
      ret = ParseForExprNode(ast, node);
    } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
      ret = ParseIfWhileExprNode(ast, node);
    }
    if (ret) {
      MS_LOG(INFO) << "Current cell is dynamic!";
      break;
    }
  }
  return ret;
}
std::string PynativeExecutor::GetCellInfo(const py::object &cell) {
  if (py::isinstance<Cell>(cell)) {
    auto c_cell = py::cast<CellPtr>(cell);
    MS_EXCEPTION_IF_NULL(c_cell);
    auto cell_info = c_cell->ToString();
    return cell_info;
  }
  return "";
}
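// Note: a cell is treated as dynamic when its construct body may take different paths between calls, e.g. an
// if/while condition or a subscript that depends on the inputs; IsDynamicCell parses the Python AST of the
// construct with the Parse* helpers above to decide.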
bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
  std::string cell_info = GetCellInfo(cell);
  if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
    return false;
  }
  // Use AST parsing to check whether the construct of the cell will change
  auto ast = std::make_shared<parse::ParseAst>(cell);
  bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
  if (!success) {
    MS_LOG(ERROR) << "Parse code to ast tree failed";
    return false;
  }
  py::object fn_node = ast->GetAstNode();
  // Get the names of the input args to initialize the dynamic variables
  ParseInputArgs(ast, fn_node);
  // Parse body context
  bool ret = false;
  ret = ParseBodyContext(ast, fn_node);
  cell_input_args_.clear();
  return ret;
}
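// Note: called when a cell's forward execution begins. Reuses the compiled graph when the cell was already built
// and is not dynamic; otherwise it creates a new FuncGraph (and possibly a new top graph) and registers the inputs.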
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
  auto cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
  // Check whether the cell needs to construct a grad graph
  if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) {
    if (top_cell_list_.empty()) {
      MS_LOG(EXCEPTION) << "Top cell list is empty";
    }
    if (IsTopestGraph(cell_id)) {
      // Clear previous step resource
      CleanPreMemoryInValueNode();
      op_index_map_.clear();
      top_cell_id_ = cell_id;
    }
    PushCurrentCellOpInfoToStack();
    MS_LOG(INFO) << "NewGraph already compiled";
    return;
  }
  // Init resource for constructing the forward graph and grad graph
  curr_g_ = std::make_shared<FuncGraph>();
  ClearResidualRes(cell_id);
  if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
    MakeNewTopGraph(cell_id, args);
  }
  PushCurrentGraphToStack();
  PushCurrentCellOpInfoToStack();
  if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) {
    auto graph_info = std::make_shared<GraphInfo>(cell_id);
    graph_info_map_[curr_g_] = graph_info;
  }
  for (size_t i = 0; i < args.size(); ++i) {
    auto param = args[i];
    auto new_param = curr_g_->add_parameter();
    std::string param_id = GetId(param);
    SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
    SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
    SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
  }
  // Check whether the construct of the cell will be changed
  if (!has_dynamic_cell_) {
    has_dynamic_cell_ = IsDynamicCell(cell);
    MS_LOG(DEBUG) << "Cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_;
  }
}
void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args) {
  for (const auto &arg : args) {
    if (py::isinstance<tensor::Tensor>(arg)) {
      auto tensor = arg.cast<tensor::TensorPtr>();
      if (tensor && tensor->is_parameter()) {
        MS_EXCEPTION(TypeError) << "The inputs could not be Parameter.";
      }
    }
  }
  // Clear resource of the old top cell
  if (CheckRealDynamicCell(cell_id)) {
    VectorClear<std::vector<TopCellInfoPtr>>(&top_cell_list_, cell_id);
  }
  CleanPreMemoryInValueNode();
  // Init resource for the new top cell
  if (!CheckCellGraph(cell_id)) {
    has_dynamic_cell_ = false;
  }
  op_index_map_.clear();
  top_cell_id_ = cell_id;
  auto df_builder = std::make_shared<FuncGraph>();
  auto graph_info = std::make_shared<GraphInfo>(cell_id);
  graph_info_map_[df_builder] = graph_info;
  auto resource = std::make_shared<pipeline::Resource>();
  resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
  auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
  top_cell_list_.emplace_back(top_cell_info);
  MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
}
std::string PynativeExecutor::GetCellOpInfo() {
  if (cell_op_info_stack_.empty()) {
    MS_LOG(EXCEPTION) << "The cell op info stack is empty";
  }
  return cell_op_info_stack_.top();
}
void PynativeExecutor::ReplaceCellOpInfoByCellId(const std::string &cell_id) {
  if (cell_id.empty()) {
    MS_LOG(EXCEPTION) << "The cell id is empty";
  }
  if (cell_op_info_stack_.empty()) {
    MS_LOG(DEBUG) << "The cell op info stack is empty; no need to replace";
    return;
  }
  cell_op_info_stack_.top() = cell_op_info_stack_.top() + cell_id;
}
void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                                  bool is_param) {
  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
    return;
  }
  auto tuple = args.cast<py::tuple>();
  auto tuple_size = static_cast<int64_t>(tuple.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    auto id = GetId(tuple[i]);
    if (is_param && node->isa<Parameter>()) {
      auto param = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      SetParamNodeMapInGraphInfoMap(g, id, param);
    }
    SetNodeMapInGraphInfoMap(g, id, node, i);
    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param);
  }
}
void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args,
                                                      const AnfNodePtr &node,
                                                      const std::vector<int64_t> &index_sequence, bool is_param) {
  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
    return;
  }
  auto tuple = args.cast<py::tuple>();
  auto tuple_size = static_cast<int64_t>(tuple.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    std::vector<int64_t> tmp = index_sequence;
    tmp.emplace_back(i);
    auto id = GetId(tuple[i]);
    if (is_param && node->isa<Parameter>()) {
      auto param = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      SetParamNodeMapInGraphInfoMap(g, id, param);
    }
    SetNodeMapInGraphInfoMap(g, id, node, tmp);
    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param);
  }
}
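// Note: called when a cell's forward execution ends. Ensures the output object has a node in the current graph
// (building a MakeTuple for tuple/list outputs) before closing the graph by its output id.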
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
  if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) {
    PopCurrentCellOpInfoFromStack();
    MS_LOG(INFO) << "EndGraph already compiled";
    return;
  }
  auto out_id = GetId(out);
  // x = op1, y = op2, return (x, y)
  auto graph_info = graph_info_map_.at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
    if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
      auto tuple = out.cast<py::tuple>();
      auto tuple_size = static_cast<int64_t>(tuple.size());
      std::vector<AnfNodePtr> inputs;
      inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
      for (int64_t i = 0; i < tuple_size; i++) {
        inputs.emplace_back(GetInput(tuple[i], false));
      }
      auto cnode = curr_g_->NewCNode(inputs);
      SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
      SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
    } else {
      MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
      MakeValueNode(out, out_id);
    }
  }
  EndGraphByOutId(cell, cell_id, out, out_id, args);
}
void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
                                       const std::string &out_id, const py::args &args) {
  AnfNodePtr output_node = GetObjNode(out, out_id);
  curr_g_->set_output(output_node);
  MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
  if (EndBpropGraph(cell_id)) {
    MS_LOG(DEBUG) << "Get bprop function cell";
    return;
  }
  auto resource = GetResource(top_cell_id_);
  MS_EXCEPTION_IF_NULL(resource);
  resource->manager()->AddFuncGraph(curr_g_);
  UpdateCellGraph(cell, curr_g_, cell_id, true, false);
  FuncGraphPtr newfg = nullptr;
  // Cell not changed
  if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) {
    MS_LOG(DEBUG) << "Cell is dynamic but not changed, no need to make ad grad";
  } else {
    MS_LOG(DEBUG) << "Need make ad grad";
    newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args);
  }
  if (graph_stack_.size() > 1) {
    std::vector<AnfNodePtr> inputs;
    inputs.emplace_back(NewValueNode(curr_g_));
    PopGraphStack();
    PopCurrentCellOpInfoFromStack();
    ReplaceCellOpInfoByCellId(cell_id);
    // Connect the previous graph to the inner graph
    auto graph_prev = graph_stack_.top();
    for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], false);
      inputs.emplace_back(input);
    }
    auto out_cnode = graph_prev->NewCNode(inputs);
    SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
    SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
    SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
  } else {
    if (newfg != nullptr) {
      DumpGraphIR("before_resolve.ir", newfg);
      parse::ResolveFuncGraph(newfg, resource);
      DumpGraphIR("after_resolve.ir", newfg);
      resource->set_func_graph(newfg);
    }
    PopGraphStack();
    PopCurrentCellOpInfoFromStack();
  }
}
bool PynativeExecutor::EndBpropGraph(const string &cell_id) {
  auto is_bprop_graph = IsBpropGraph(cell_id);
  if (is_bprop_graph) {
    if (!IsNestedGrad()) {
      PopGraphStack();
      PopCurrentCellOpInfoFromStack();
      ReplaceCellOpInfoByCellId(cell_id);
    }
    return true;
  }
  return false;
}
bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) {
  bool res = false;
  if (CheckRealDynamicCell(cell_id)) {
    MS_LOG(DEBUG) << "Cur cell " << cell_id << " is dynamic, no need to check";
    return true;
  }
  if (GetCellOpInfo().empty()) {
    MS_LOG(DEBUG) << "Cell op info is empty";
    return true;
  }
  auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                         [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
  if (it == cell_graph_list_.end() || IsFirstGradStep(top_cell_id_)) {
    return true;
  }
  MS_LOG(DEBUG) << "Cell op info " << GetCellOpInfo() << ", old " << (*it)->cell_ops_info.at((*it)->call_times);
  if ((*it)->cell_ops_info.at((*it)->call_times) != GetCellOpInfo()) {
    res = true;
    UpdateCellDynamic(cell_id);
    MS_LOG(DEBUG) << "Cell self changed";
  }
  (*it)->call_times = (*it)->call_times < (*it)->cell_ops_info.size() - 1 ? (*it)->call_times + 1 : 0;
  return res;
}
void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) {
  for (auto &it : cell_graph_list_) {
    if (it->cell_id != cell_id) {
      it->is_real_dynamic = true;
      continue;
    }
    it->is_real_dynamic = true;
    break;
  }
}
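// Note: maintains cell_graph_list_. It records op info per call, clones the forward graph when it must survive
// later mutation, and handles custom-bprop cells separately since they only keep the backward graph.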
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                                       bool need_cloned, bool is_grad) {
  auto update_in_endgraph = need_cloned && !is_grad;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    // Bprop only saves the backward graph
    auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                           [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
    if (it != cell_graph_list_.end()) {
      (*it)->is_grad = is_grad;
      if (g != (*it)->fg) {
        graph_info_map_.update((*it)->fg, g);
        (*it)->fg = g;
      }
      if (update_in_endgraph && IsFirstGradStep(top_cell_id_)) {
        (*it)->cell_ops_info.emplace_back(GetCellOpInfo());
      }
      MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id;
    } else {
      py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
      auto bprop_func_cell_id = GetId(bprop_func);
      MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id
                    << " cell ops info " << GetCellOpInfo();
      auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id);
      cell_info->cell_ops_info.emplace_back(GetCellOpInfo());
      cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
    }
    return;
  }
  FuncGraphPtr tmp = g;
  if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) {
    MS_LOG(DEBUG) << "No need cloned";
    need_cloned = false;
  }
  auto clone_fn = [&g, &tmp, need_cloned, this]() {
    if (!need_cloned) {
      return;
    }
    tmp = BasicClone(g);
    graph_info_map_.update(g, tmp);
    ClearCnodeRes(tmp->output());
  };
  // First call, or the cell id does not exist
  if (update_in_endgraph && (IsFirstGradStep(top_cell_id_) || !CheckCellGraph(cell_id))) {
    if (!CheckCellGraph(cell_id)) {
      clone_fn();
      MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo();
      auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, tmp, cell_id, "");
      cell_info->cell_ops_info.emplace_back(GetCellOpInfo());
      cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
    } else {
      auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                             [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
      if (it != cell_graph_list_.end()) {
        (*it)->cell_ops_info.emplace_back(GetCellOpInfo());
      }
      MS_LOG(DEBUG) << "Add another same cell ops info";
    }
    return;
  }
  for (auto &it : cell_graph_list_) {
    if (it->cell_id != cell_id) {
      continue;
    }
    if (IsFirstGradStep(cell_id)) {
      // Do not compute grad
      it->is_grad = is_grad;
    }
    if (need_cloned) {
      clone_fn();
      if (it->fg != nullptr) {
        graph_info_map_.erase(it->fg);
      }
      MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with cloned new " << tmp.get();
      it->fg = tmp;
    }
    if (!need_cloned && !is_grad) {
      graph_info_map_.erase(it->fg);
      MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with new " << tmp.get();
      it->fg = tmp;
    }
    break;
  }
}
void PynativeExecutor::ClearCnodeRes(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return;
  }
  auto cnode = node->cast<CNodePtr>();
  cnode->clear_inputs_value();
  for (size_t i = 0; i < cnode->size(); ++i) {
    auto n = cnode->input(i);
    cnode->set_forward(nullptr, "");
    ClearCnodeRes(n);
  }
}
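// Note: builds the grad graph via ad::Grad. For cells with a custom bprop, the backward graph is attached as a
// transform and the forward parameters are replaced by the captured argument values.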
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
                                             const std::string &cell_id, const py::args &args) {
  bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
  if (is_custom_bprop) {
    size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
    if (par_number > 0) {
      MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
                        << " parameters that are not supported in the net.";
    }
    MS_LOG(INFO) << "Use cell custom bprop function.";
    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
    if (bprop_graph != nullptr) {
      (void)g->transforms().emplace(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
    }
  }
  DumpGraphIR("fg.ir", g);
  auto is_top = IsTopGraph(cell_id);
  MS_LOG(DEBUG) << "Grad top cell " << is_top;
  set_need_replace_forward(!IsNestedGrad());
  // Obtain grad graph
  auto newfg = ad::Grad(g, r, is_top);
  if (is_custom_bprop) {
    auto params = newfg->parameters();
    auto manager = Manage({newfg}, false);
    if (args.size() > params.size()) {
      MS_EXCEPTION(TypeError) << "The number of arguments " << args.size()
                              << " is more than the number of parameters required, which is " << params.size();
    }
    for (size_t i = 0; i < args.size(); i++) {
      ValuePtr value = PyAttrValue(args[i]);
      auto v_node = NewValueNode(value);
      manager->Replace(params[i], v_node);
    }
    UpdateCellGraph(cell, newfg, cell_id, false, false);
  }
  return newfg;
}
std::string PynativeExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args,
                                            py::object *forward_args, py::object *sens) {
  auto size = args.size();
  size_t forward_args_size = size;
  if (has_sens) {
    if (size >= 1) {
      --forward_args_size;
      if (sens != nullptr) {
        *sens = args[forward_args_size];
      }
    }
    py::tuple f_args(forward_args_size);
    for (size_t i = 0; i < forward_args_size; ++i) {
      f_args[i] = args[i];
    }
    *forward_args = f_args;
  }
  const auto &cell_id = GetCellId(cell, *forward_args);
  return cell_id;
}
void PynativeExecutor::SaveAllValueNodeTensors(const FuncGraphPtr &graph) {
  std::unordered_set<tensor::TensorPtr> all_value_node_tensors;
  auto trace_function = [&all_value_node_tensors](const AnfNodePtr &anf_node) {
    auto value = GetValueNode(anf_node);
    if (value) {
      if (value->isa<tensor::Tensor>()) {
        auto tensor = value->cast<tensor::TensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        if (tensor->device_address()) {
          all_value_node_tensors.emplace(tensor);
        }
      } else if (value->isa<ValueTuple>()) {
        auto tuple = value->cast<ValueTuplePtr>();
        MS_EXCEPTION_IF_NULL(tuple);
        for (size_t i = 0; i < tuple->size(); i++) {
          if ((*tuple)[i]->isa<tensor::Tensor>()) {
            auto tensor = (*tuple)[i]->cast<tensor::TensorPtr>();
            MS_EXCEPTION_IF_NULL(tensor);
            if (tensor->device_address()) {
              all_value_node_tensors.emplace(tensor);
            }
          }
        }
      }
    }
    return FOLLOW;
  };
  (void)TopoSort(graph->get_return(), SuccDeeperSimple, trace_function);
  all_value_node_tensors_ = all_value_node_tensors;
}
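// Note: drives grad graph construction for GradOperation. A compiled grad graph is reused when neither the params
// nor the cell changed; otherwise params and weights are set up, ad runs, and the resource is optimized and compiled.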
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                    const py::args &args) {
  auto size = args.size();
  py::object sens = py::none();
  py::object forward_args = args;
  const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args, &forward_args, &sens);
  MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id;
  const auto &params_changed = CheckGradParamsChanged(cell_id, weights, sens);
  if (!params_changed && !IsFirstGradStep(cell_id) && !CheckRealDynamicCell(cell_id)) {
    UpdateTopCellCompileInfo(cell_id, false);
    ClearDynamicTopRes(cell_id);
    MS_LOG(INFO) << "Grad graph already compiled";
    return;
  }
  // Nested graph
  if (CheckCellGraph(cell_id) && !graph_stack_.empty()) {
    MS_LOG(DEBUG) << "Set nested top graph";
    SetNestedTopGraph(cell, forward_args, cell_id);
  }
  auto df_builder = GetDfbuilder(cell_id);
  MS_EXCEPTION_IF_NULL(df_builder);
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
  // Set all params (inputs + weights)
  SetGradGraphParams(df_builder, resource, size);
  // Get params (weights) that require derivatives
  auto w_args = GetWeightsArgs(weights, df_builder);
  // Get the parameter items and add the values to args_spec
  auto args_spec = GetArgsSpec(args, df_builder);
  resource->set_args_spec(args_spec);
  // Get the real grad graph
  DumpGraphIR("before_grad.ir", resource->func_graph());
  GradGraph(resource->func_graph(), grad, w_args, size, cell_id);
  DumpGraphIR("after_grad.ir", df_builder);
  resource->set_func_graph(df_builder);
  resource->manager()->KeepRoots({df_builder});
  resource->results()[pipeline::kBackend] = compile::CreateBackend();
  MS_LOG(INFO) << "Start opt";
  if (has_dynamic_cell_) {
    SaveAllValueNodeTensors(resource->func_graph());
  }
  PynativeOptimizeAction(resource);
  DumpGraphIR("after_opt.ir", resource->func_graph());
  SaveTensorsInValueNode(resource);
  TaskEmitAction(resource);
  ExecuteAction(resource);
  ClearUselessRes(df_builder, cell, cell_id);
  UpdateCellGraph(cell, curr_g_, cell_id, false, true);
  UpdateTopCellCompileInfo(cell_id, true);
  resource->Clean();
}
void PynativeExecutor::ClearDynamicTopRes(const std::string &cell_id) {
  if (IsTopestGraph(cell_id)) {
    op_index_map_.clear();
  }
  // Delete unused top cell resources
  if (!CheckDynamicCell(cell_id)) {
    return;
  }
  int same_top_cell_count = 0;
  for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
    if ((*it)->cell_id == cell_id) {
      ++same_top_cell_count;
      if (same_top_cell_count > 1) {
        graph_info_map_.erase((*it)->df_builder);
        it = top_cell_list_.erase(it);
        --same_top_cell_count;
      } else {
        ++it;
      }
    } else {
      ++it;
    }
  }
}
bool PynativeExecutor::CheckGradParamsChanged(const std::string &cell_id, const py::object &weights,
                                              const py::object &sens) {
  bool res = false;
  auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
  if (it == top_cell_list_.end()) {
    return res;
  }
  auto fn = [](const py::object &arg) {
    std::string arg_id;
    if (py::isinstance<tensor::Tensor>(arg)) {
      auto tensor_ptr = py::cast<tensor::TensorPtr>(arg);
      auto dtype = tensor_ptr->data_type();
      auto shape = tensor_ptr->shape();
      std::stringstream ss;
      std::for_each(shape.begin(), shape.end(), [&ss](int i) { ss << i; });
      arg_id = ss.str() + std::to_string(dtype);
    } else {
      arg_id = std::string(py::str(arg));
    }
    return arg_id;
  };
  std::string sens_id = "sens";
  if (!py::isinstance<py::none>(sens)) {
    sens_id = fn(sens);
  }
  if (!(*it)->sens_id.empty() && (*it)->sens_id != sens_id) {
    (*it)->sens_id = sens_id;
  }
  std::string weights_id = fn(weights);
  if (!(*it)->weights_id.empty() && (*it)->weights_id != weights_id) {
    (*it)->weights_id = weights_id;
    res = true;
  }
  return res;
}
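// SetNestedTopGraph registers a nested cell as its own top graph: it reuses
// or creates a pipeline::Resource, allocates a fresh df_builder FuncGraph,
// and either adopts the custom-bprop backward graph directly or clones the
// weight parameters and builds a grad graph from the recorded forward graph.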
void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id) {
  if (IsTopGraph(cell_id)) {
    VectorClear<std::vector<TopCellInfoPtr>>(&top_cell_list_, cell_id);
  }
  ResourcePtr resource = nullptr;
  auto ia = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
  if (ia != top_cell_list_.end()) {
    resource = GetResource((*ia)->cell_id);
    MS_EXCEPTION_IF_NULL(resource);
    MS_LOG(DEBUG) << "Find old resource " << resource.get();
  }
  if (resource == nullptr) {
    resource = std::make_shared<pipeline::Resource>();
    resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
    MS_LOG(DEBUG) << "Make new resource " << resource.get();
  }
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
  auto graph_info = std::make_shared<GraphInfo>(cell_id);
  graph_info_map_[df_builder] = graph_info;
  auto top_cell_info = std::make_shared<TopCellInfo>(false, resource, df_builder, cell_id);
  top_cell_list_.emplace_back(top_cell_info);
  FuncGraphPtr forward_graph = nullptr;
  auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                         [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
  if (ib != cell_graph_list_.end()) {
    forward_graph = (*ib)->fg;
  }
  MS_EXCEPTION_IF_NULL(forward_graph);
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    DumpGraphIR("nested_bprop.ir", forward_graph);
    // A custom bprop takes the backward graph (before opt) and uses it like any other forward graph
    curr_g_ = forward_graph;
    resource->set_func_graph(forward_graph);
    return;
  }
  // Copy weight parameters
  ReplaceGraphParams(df_builder, forward_graph, cell_id);
  resource->manager()->AddFuncGraph(forward_graph);
  DumpGraphIR("nested_fg.ir", forward_graph);
  set_need_replace_forward(false);
  auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args);
  resource->set_func_graph(newfg);
}
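// ReplaceGraphParams re-homes the weight parameters (those with default
// values) of the forward graph, and of every cell graph recorded after it,
// onto the new df_builder, remembering the old/new pairs in
// replace_weights_map_ so RecoverGraphParams can undo the replacement later.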
void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
                                          const std::string &cell_id) {
  std::vector<FuncGraphPtr> graph_before{};
  bool index_find = false;
  for (const auto &it : cell_graph_list_) {
    if (IsBpropGraph(it->cell_id) || it->fg == nullptr) {
      continue;
    }
    if (index_find) {
      graph_before.emplace_back(it->fg);
      continue;
    }
    if (it->cell_id == cell_id) {
      index_find = true;
      graph_before.emplace_back(it->fg);
    }
  }
  auto manager = Manage({forward_graph}, false);
  for (const auto &f : graph_before) {
    auto graph_info = graph_info_map_.at(f);
    MS_EXCEPTION_IF_NULL(graph_info);
    for (const auto &it : graph_info->params) {
      if (!it.second->has_default()) {
        continue;
      }
      auto new_param = df_builder->add_parameter();
      new_param->set_abstract(it.second->abstract());
      new_param->set_name(it.second->name());
      new_param->set_default_param(it.second->default_param());
      ScopePtr scope = (it.second->scope() != kDefaultScope) ? it.second->scope() : kDefaultScope;
      new_param->set_scope(scope);
      manager->Replace(it.second, new_param);
      replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param));
      MS_LOG(DEBUG) << "Param name " << new_param->name() << " ptr " << new_param.get();
      auto graph_info_of_df_builder = graph_info_map_.at(df_builder);
      MS_EXCEPTION_IF_NULL(graph_info_of_df_builder);
      graph_info_of_df_builder->params[it.first] = new_param;
      SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param);
      SetNodeMapInGraphInfoMap(df_builder, it.first, new_param);
    }
  }
}
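// The grad graph's parameter list is "inputs first, then weights": `size`
// fresh input parameters are prepended to the weight parameters that
// GetInput already registered on df_builder.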
void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) {
  std::vector<AnfNodePtr> new_params;
  for (size_t i = 0; i < size; i++) {
    ParameterPtr p = std::make_shared<Parameter>(df_builder);
    new_params.emplace_back(p);
  }
  MS_LOG(DEBUG) << "GradNet weight param size " << df_builder->parameters().size();
  // df_builder->parameters() were set in GetInput; they are the weight params
  new_params.insert(new_params.end(), df_builder->parameters().begin(), df_builder->parameters().end());
  df_builder->set_parameters(new_params);
  resource->manager()->SetParameters(df_builder, new_params);
}
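// GetWeightsArgs turns the Python weights object (expected to carry the
// __parameter_tuple__ attribute) into a MakeTuple node over the matching
// Parameter nodes of df_builder, adding a free parameter with the tensor as
// its default value for any weight not yet seen in the graph.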
std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) {
  std::vector<AnfNodePtr> w_args;
  if (!py::hasattr(weights, "__parameter_tuple__")) {
    MS_LOG(DEBUG) << "No __parameter_tuple__ attribute found";
    return {};
  }
  auto tuple = weights.cast<py::tuple>();
  MS_LOG(DEBUG) << "Get weights tuple size " << tuple.size();
  w_args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t it = 0; it < tuple.size(); ++it) {
    auto param = tuple[it];
    auto param_id = GetId(param);
    AnfNodePtr para_node = nullptr;
    auto graph_info = graph_info_map_.at(df_builder);
    MS_EXCEPTION_IF_NULL(graph_info);
    if (graph_info->params.find(param_id) != graph_info->params.end() &&
        graph_info->node_map.find(param_id) != graph_info->node_map.end()) {
      para_node = graph_info->node_map[param_id].first;
    } else {
      auto name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
      if (py::isinstance<py::none>(name_attr)) {
        MS_LOG(EXCEPTION) << "Parameter object should have a name attribute";
      }
      auto param_name = py::cast<std::string>(name_attr);
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(param);
      free_param->set_default_param(value);
      free_param->debug_info()->set_name(param_name);
      para_node = free_param;
    }
    w_args.emplace_back(para_node);
  }
  return w_args;
}
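// GetArgsSpec builds the abstract signature used by the compile pipeline:
// broadened abstracts for the positional inputs, followed by the abstracts
// of the weight parameters' default values; the same abstracts are attached
// to the corresponding Parameter nodes.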
abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder) {
  abstract::AbstractBasePtrList args_spec;
  std::size_t size = args.size();
  auto df_params = df_builder->parameters();
  if (df_params.size() < size) {
    MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << size;
  }
  // Input params
  for (std::size_t i = 0; i < size; i++) {
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(args[i], &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "Args convert error";
    }
    bool broaden = true;
    auto abs = abstract::FromValue(converted, broaden);
    args_spec.emplace_back(abs);
    auto param_node = std::static_pointer_cast<Parameter>(df_params[i]);
    param_node->set_abstract(abs);
  }
  // Weight params
  for (const auto &param : df_params) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    if (param_node->has_default()) {
      ValuePtr value = param_node->default_param();
      auto ptr = value->ToAbstract();
      MS_EXCEPTION_IF_NULL(ptr);
      args_spec.emplace_back(ptr);
      param_node->set_abstract(ptr);
    }
  }
  MS_LOG(DEBUG) << "Args_spec size " << args_spec.size();
  return args_spec;
}
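// GradGraph asks the GradOperation for the bprop graph of `g` with respect
// to the top graph's inputs and the requested weights, then makes df_builder
// call that graph on its first `arg_size` parameters and sets the result as
// df_builder's output.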
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op,
                                 const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
  FuncGraphPtr top_g = nullptr;
  auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                         [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
  if (it != cell_graph_list_.end()) {
    top_g = (*it)->fg;
  }
  MS_EXCEPTION_IF_NULL(top_g);
  auto nparam = top_g->parameters().size();
  MS_LOG(DEBUG) << "Top graph input params size " << nparam;
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  auto df_builder = GetDfbuilder(cell_id);
  MS_EXCEPTION_IF_NULL(df_builder);
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  df_builder->debug_info()->set_name(ss.str());
  auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g->parameters(), weights);
  std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
  auto df_params = df_builder->parameters();
  if (df_params.size() < arg_size) {
    MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << arg_size;
  }
  for (size_t i = 0; i < arg_size; ++i) {
    inputs.emplace_back(df_params[i]);
  }
  auto out = df_builder->NewCNode(inputs);
  df_builder->set_output(out);
  resource->manager()->AddFuncGraph(df);
  resource->manager()->AddFuncGraph(df_builder);
}
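// ClearUselessRes drops per-cell bookkeeping once the topmost cell has
// finished compiling: the forward FuncGraphs of the topmost cell and of
// everything recorded after it are released, and their graph info entries
// erased. Custom-bprop cells and dynamic cells in their first grad step are
// left untouched.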
void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell,
                                       const std::string &cell_id) {
  graph_info_map_.erase(df_builder);
  bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
  bool is_dynamic_top_first_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(cell_id);
  bool is_topmost = IsTopestGraph(cell_id) && top_cell_list_.front()->cell_id == cell_id;
  if (has_custom_bprop || is_dynamic_top_first_grad || !is_topmost) {
    return;
  }
  MS_LOG(DEBUG) << "Update topmost cell graph list and graph info map";
  // Clear graph_info_map_
  std::vector<std::string> l{};
  bool index_find = false;
  for (auto &it : cell_graph_list_) {
    if (index_find) {
      it->fg = nullptr;
      l.emplace_back(it->cell_id);
      continue;
    }
    if (it->cell_id == cell_id) {
      index_find = true;
      it->fg = nullptr;
      l.emplace_back(it->cell_id);
    }
  }
  for (const auto &it : l) {
    for (auto ic = graph_info_map_.begin(); ic != graph_info_map_.end();) {
      if (ic->second->cell_id.find(it) != std::string::npos) {
        ic = graph_info_map_.erase(ic);
      } else {
        ++ic;
      }
    }
  }
}
py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  BaseRef ret = false;
  AddNestedGradOrder();
  if (!grad_running()) {
    MS_LOG(DEBUG) << "Grad not running yet";
    return BaseRefToPyData(ret);
  }
  const auto &cell_id = GetCellId(cell, args);
  std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
  MS_LOG(DEBUG) << "Key is " << key;
  for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) {
    MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id;
    if (key != (*it)->cell_id.substr(0, std::min(PTR_LEN, (*it)->cell_id.size()))) {
      continue;
    }
    MS_LOG(DEBUG) << "Delete cell id from cell graph list";
    graph_info_map_.erase((*it)->fg);
    cell_graph_list_.erase(it);
    ret = true;
    break;
  }
  return BaseRefToPyData(ret);
}
py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
  const auto &cell_id = GetCellId(cell, args);
  bool already_run = CheckCellGraph(cell_id);
  MS_LOG(DEBUG) << "Graph has already run " << already_run << " cell id " << cell_id;
  return BaseRefToPyData(already_run);
}
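// Run executes the compiled grad graph through the VM backend. A sens input
// is detected when the cell id extends a known top cell id; after execution,
// a nested cnode is emitted into the enclosing graph if we are inside an
// outer grad (see MakeNestedCnode / MakeBpropNestedCnode below).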
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) {
  auto cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "Run start cell id " << cell_id;
  bool has_sens = false;
  for (const auto &it : top_cell_list_) {
    if (cell_id.find(it->cell_id) != std::string::npos && cell_id != it->cell_id) {
      has_sens = true;
      break;
    }
  }
  py::object forward_args = args;
  cell_id = GetGradCellId(has_sens, cell, args, &forward_args);
  MS_LOG(DEBUG) << "Run has sens " << has_sens << " forward cell id " << cell_id;
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Run resource ptr " << resource.get();
  VectorRef arg_list;
  py::tuple converted_args = ConvertArgs(args);
  pipeline::ProcessVmArgInner(converted_args, resource, &arg_list);
  if (resource->results().find(pipeline::kOutput) == resource->results().end()) {
    MS_LOG(EXCEPTION) << "Can't find run graph output";
  }
  if (!resource->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Run graph is not VmEvalFuncPtr";
  }
  compile::VmEvalFuncPtr run = resource->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  MS_EXCEPTION_IF_NULL(run);
  std::string backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  set_grad_runing(true);
  BaseRef value = (*run)(arg_list);
  set_grad_runing(false);
  MS_LOG(DEBUG) << "Eval run end " << value.ToString();
  auto out = BaseRefToPyData(value);
  auto do_vm_compiled =
    std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) {
      return value->cell_id == cell_id && value->do_vm_compiled;
    });
  if (do_vm_compiled) {
    if (MakeBpropNestedCnode(cell, out, cell_id)) {
      return out;
    }
    MakeNestedCnode(cell_id, args, resource, out, has_sens);
  }
  return out;
}
bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) {
  if (graph_stack_.empty() || !py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    MS_LOG(DEBUG) << "No nested bprop grad found";
    return false;
  }
  auto out_id = GetId(out);
  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(curr_g_));
  PopGraphStack();
  auto graph_info = graph_info_map_.at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  for (const auto &ig : graph_info->params) {
    if (!ig.second->has_default()) {
      inputs.emplace_back(ig.second);
    }
  }
  auto cnode = curr_g_->NewCNode(inputs);
  SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
  SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
  MS_LOG(DEBUG) << "Custom bprop nested cnode is " << cnode->DebugString(4);
  return true;
}
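// MakeNestedCnode splices the freshly compiled grad graph into the enclosing
// graph on the stack as a call node, feeding it the forward inputs (minus the
// sens argument) plus any weights recovered from replace_weights_map_.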
void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
                                       const py::object &out, bool has_sens) {
  if (graph_stack_.empty()) {
    MS_LOG(DEBUG) << "No nested grad found";
    return;
  }
  auto graph_prev = graph_stack_.top();
  MS_EXCEPTION_IF_NULL(graph_prev);
  MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get();
  auto newfg = resource->func_graph();
  MS_EXCEPTION_IF_NULL(newfg);
  auto inputs_size = args.size();
  if (has_sens) {
    inputs_size -= 1;
  }
  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(newfg));
  for (size_t i = 0; i < inputs_size; ++i) {
    inputs.emplace_back(GetInput(args[i], false));
  }
  if (newfg->parameters().size() > args.size()) {
    RecoverGraphParams(newfg, cell_id, &inputs);
  }
  auto out_id = GetId(out);
  auto cnode = graph_prev->NewCNode(inputs);
  SetTupleArgsToGraphInfoMap(graph_prev, out, cnode);
  SetNodeMapInGraphInfoMap(graph_prev, out_id, cnode);
  MS_LOG(DEBUG) << "Nested cnode is " << cnode->DebugString(4);
}
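// RecoverGraphParams maps the weight parameters of the nested grad graph back
// to the original parameters captured in ReplaceGraphParams, matching by
// name, so the outer graph keeps referring to a single set of weight nodes.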
void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id,
                                          std::vector<AnfNodePtr> *inputs) {
  FuncGraphPtr forward_graph = nullptr;
  auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                         [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
  if (ic != cell_graph_list_.end()) {
    forward_graph = (*ic)->fg;
  }
  MS_EXCEPTION_IF_NULL(forward_graph);
  auto param_list = replace_weights_map_.at(forward_graph);
  auto params = newfg->parameters();
  auto manage = Manage({newfg}, false);
  for (const auto &it : params) {
    auto param = it->cast<ParameterPtr>();
    if (!param->has_default()) {
      continue;
    }
    for (auto p = param_list.begin(); p != param_list.end();) {
      MS_LOG(DEBUG) << "Param name " << param->name() << " ptr " << param.get();
      if (p->second->name() == param->name()) {
        manage->Replace(param, p->first);
        inputs->emplace_back(p->first);
        param_list.erase(p);
        break;
      } else {
        // Advance the iterator on mismatch; without this the loop never terminates
        ++p;
      }
    }
  }
  replace_weights_map_.erase(forward_graph);
}
void PynativeExecutor::Clear(const std::string &cell_id) {
  if (cell_id.empty()) {
    Clean();
    return;
  }
  MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id;
  for (auto it = graph_info_map_.begin(); it != graph_info_map_.end();) {
    if (it->second->cell_id.find(cell_id) != std::string::npos) {
      it = graph_info_map_.erase(it);
    } else {
      ++it;
    }
  }
  // May exit during the run-op step
  auto ms_context = MsContext::GetInstance();
  if (ms_context != nullptr) {
    ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  }
  ConfigManager::GetInstance().ResetIterNum();
  VectorClear<std::vector<CellInfoPtr>>(&cell_graph_list_, cell_id);
  VectorClear<std::vector<TopCellInfoPtr>>(&top_cell_list_, cell_id);
  node_abs_map_.clear();
}
void PynativeExecutor::Clean() {
  MS_LOG(DEBUG) << "Clean";
  SubNestedGradOrder();
  node_abs_map_.clear();
  obj_to_forward_id_.clear();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
void PynativeExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear all res";
  Clean();
  graph_id_ = 0;
  grad_order_ = 0;
  grad_flag_ = false;
  has_dynamic_cell_ = false;
  grad_is_running_ = false;
  need_replace_forward_ = true;
  curr_g_ = nullptr;
  graph_info_map_.clear();
  replace_weights_map_.clear();
  cell_graph_list_.clear();
  top_cell_list_.clear();
  op_index_map_.clear();
  cell_op_index_with_tensor_id_.clear();
  cell_tensor_id_with_tensor_.clear();
  prim_abs_list_.clear();
  std::stack<FuncGraphPtr>().swap(graph_stack_);
}
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args);
}
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  MS_LOG(DEBUG) << "Enter end graph process.";
  auto &mem_cleaner = pipeline::Resource::mem_cleaner();
  mem_cleaner.EnterPynativeEndGraphProcess();
  PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
  mem_cleaner.LeavePynativeEndGraphProcess();
  MS_LOG(DEBUG) << "Leave end graph process.";
}
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
}
void PynativeExecutor::Sync() {
  if (session == nullptr) {
    MS_EXCEPTION(NotExistsError) << "No session has been created!";
  }
  session->SyncStream();
}
void PynativeExecutor::EnterConstruct(const py::object &cell) {
  if (top_cell_ != nullptr) {
    return;
  }
  top_cell_ = cell.ptr();
  pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
  MS_LOG(DEBUG) << "Enter construct process.";
}
void PynativeExecutor::LeaveConstruct(const py::object &cell) {
  if (top_cell_ != cell.ptr()) {
    return;
  }
  top_cell_ = nullptr;
  pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
  MS_LOG(DEBUG) << "Leave construct process.";
}
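// Python bindings for the executor. A rough usage sketch from the Python side
// (hypothetical variable names; the real call sequence lives in MindSpore's
// Python Cell/GradOperation wrappers):
//   executor = PynativeExecutor_.get_instance()
//   executor.new_graph(cell, *args)        # start recording the forward
//   out = ...                              # run the cell op by op
//   executor.end_graph(cell, out, *args)   # finish the forward graph
//   executor.grad_net(grad_op, cell, weights, *args)  # compile the grad graph
//   grads = executor(cell, args, phase)    # __call__ -> Run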
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.")
                           .def("check_run", &PynativeExecutor::CheckAlreadyRun, "pynative check whether the graph ran before.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
                           .def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
                           .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.")
                           .def("enter_construct", &PynativeExecutor::EnterConstruct,
                                "Do something before entering the construct function.")
                           .def("leave_construct", &PynativeExecutor::LeaveConstruct,
                                "Do something after leaving the construct function.");
                       }));
}  // namespace mindspore::pynative