
pynative_execute.cc

/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/pynative/pynative_execute.h"
#include <typeinfo>
#include <set>
#include <memory>
#include <sstream>
#include <algorithm>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "debug/trace.h"
#include "debug/anf_ir_dump.h"
#include "pybind_api/api_register.h"
#include "pybind_api/pybind_patch.h"
#include "pybind_api/ir/tensor_py.h"
#include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/cell.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "utils/check_convert_utils.h"
#include "utils/context/context_extends.h"
#include "utils/config_manager.h"
#include "utils/convert_utils_py.h"
#include "utils/scoped_long_running.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/action.h"
#include "pipeline/jit/pass.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/parse_dynamic.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "pipeline/jit/pipeline.h"
#include "pipeline/jit/resource.h"
#include "pipeline/pynative/base.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/hardware/device_context_manager.h"
#include "vm/transform.h"
#ifdef ENABLE_GE
#include "pipeline/pynative/pynative_execute_ge.h"
#endif

using mindspore::tensor::TensorPy;

namespace mindspore::pynative {
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
ForwardExecutorPtr PynativeExecutor::forward_executor_ = nullptr;
GradExecutorPtr PynativeExecutor::grad_executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;

namespace {
const size_t PTR_LEN = 15;
const size_t ARG_SIZE = 2;
const size_t MAX_TOP_CELL_COUNTS = 20;

// Primitives that cannot infer the value of a constant input in PyNative mode.
const std::set<std::string> kVmOperators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
                                            "mixed_precision_cast"};
const char kOpsFunctionModelName[] = "mindspore.ops.functional";
const char kGrad[] = "grad";
std::map<std::string, std::shared_ptr<session::SessionBasic>> kSessionBackends;
std::map<std::string, std::shared_ptr<compile::MindRTBackend>> kMindRtBackends;
PyObjectIdCache g_pyobj_id_cache;
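
// Note: PynativeExecutorTry below centralizes exception handling for the executor entry points. It
// runs `method`, and on any failure clears the executor state via ClearRes() before re-throwing, so
// a Python-level error does not leave the PyNative executor holding stale graphs or resources. Each
// pybind11 exception type is re-thrown as-is so it reaches the Python interpreter with its original type.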
template <typename T, typename... Args>
void PynativeExecutorTry(const std::function<void(T *ret, const Args &...)> &method, T *ret, const Args &... args) {
  const auto inst = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(inst);
  MS_EXCEPTION_IF_NULL(method);
  try {
    method(ret, args...);
  } catch (const py::error_already_set &ex) {
    // Print the function call stack info before release.
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // Call py::print to output the function call stack to STDOUT, so that even when the log is written
    // to a file, the user can see this info on screen without opening the log file.
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    inst->ClearRes();
    // Re-throw this exception to the Python interpreter to handle it.
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    inst->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    inst->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    inst->ClearRes();
    throw py::index_error(ex);
  } catch (const py::name_error &ex) {
    inst->ClearRes();
    throw py::name_error(ex);
  } catch (const std::exception &ex) {
    inst->ClearRes();
    // Re-throw this exception to the Python interpreter to handle it.
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    inst->ClearRes();
    auto exception_type = abi::__cxa_current_exception_type();
    MS_EXCEPTION_IF_NULL(exception_type);
    std::string ex_name(exception_type->name());
    MS_LOG(EXCEPTION) << "Error occurred when compiling graph. Exception name: " << ex_name;
  }
}

inline ValuePtr PyObjToValue(const py::object &obj) {
  ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  if (!converted_ret) {
    MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
  }
  return converted_ret;
}
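
// Note on the id scheme used below: GetId assigns every Python argument a stable string key, used for
// caching (e.g. the abstract cache and top-cell input ids). Parameters use their parameter name, tensors
// their tensor id, scalars their literal value, and tuples/lists concatenate the ids of their elements.
// Cells and functions go through the Python adapter and are memoized in g_pyobj_id_cache, presumably
// because that lookup crosses back into Python.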
std::string GetPyObjId(const py::handle &obj) {
  py::object out = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  if (py::isinstance<py::none>(out)) {
    MS_LOG(EXCEPTION) << "Get pyobj failed";
  }
  return out.cast<std::string>();
}

std::string GetId(const py::handle &obj) {
  if (py::isinstance<tensor::Tensor>(obj)) {
    auto tensor_ptr = py::cast<tensor::TensorPtr>(obj);
    if (tensor_ptr->is_parameter()) {
      const auto &param_info = tensor_ptr->param_info();
      MS_EXCEPTION_IF_NULL(param_info);
      return param_info->name();
    }
    return tensor_ptr->id();
  } else if (py::isinstance<mindspore::Type>(obj)) {
    auto type_ptr = py::cast<mindspore::TypePtr>(obj);
    return "type" + type_ptr->ToString();
  } else if (py::isinstance<py::str>(obj) || py::isinstance<py::int_>(obj) || py::isinstance<py::float_>(obj)) {
    return std::string(py::str(obj));
  } else if (py::isinstance<py::none>(obj)) {
    return "none";
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    auto p_list = py::cast<py::tuple>(obj);
    string prefix = py::isinstance<py::tuple>(obj) ? "tuple:" : "list";
    if (p_list.empty()) {
      prefix = "empty";
    } else {
      std::string key;
      for (size_t i = 0; i < p_list.size(); ++i) {
        key += std::string(py::str(GetId(p_list[i]))) + ":";
      }
      prefix += key;
    }
    return prefix;
  }
  if (py::isinstance<Cell>(obj) || py::isinstance<py::function>(obj)) {
    auto it = g_pyobj_id_cache.find(obj);
    if (it == g_pyobj_id_cache.end()) {
      auto id = GetPyObjId(obj);
      g_pyobj_id_cache.emplace(obj, id);
      return id;
    } else {
      return it->second;
    }
  } else {
    return GetPyObjId(obj);
  }
}

void GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes,
                  mindspore::HashMap<SignatureEnumDType, std::vector<size_t>> *type_indexes) {
  MS_EXCEPTION_IF_NULL(type_indexes);
  for (size_t i = 0; i < dtypes.size(); ++i) {
    auto it = type_indexes->find(dtypes[i]);
    if (it == type_indexes->end()) {
      (void)type_indexes->emplace(std::make_pair(dtypes[i], std::vector<size_t>{i}));
    } else {
      it->second.emplace_back(i);
    }
  }
}
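
// Note: JudgeMaxType below adjusts the tensor-derived maximum type for implicit conversion. Bool is
// widened when scalar int64/float32 arguments are present, any scalar float32 forces a float result
// unless the maximum is already a float type, and mixing a uint8 maximum with an int8 tensor promotes
// to int16 so that both value ranges fit.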
TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
  if (max_type == TypeId::kNumberTypeBool) {
    if (has_scalar_int64) {
      max_type = TypeId::kNumberTypeInt64;
    }
    if (has_scalar_float32) {
      max_type = TypeId::kNumberTypeFloat32;
    }
  }
  if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
      max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
    max_type = TypeId::kNumberTypeFloat32;
  }
  if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
    max_type = TypeId::kNumberTypeInt16;
  }
  return max_type;
}

std::string GetCurrentDeviceTarget(const std::string &device_target, const PrimitivePyPtr &op_prim) {
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &attr_map = op_prim->attrs();
  auto iter = attr_map.find("primitive_target");
  if (iter != attr_map.end()) {
    return GetValue<std::string>(iter->second);
  }
  return device_target;
}

session::SessionPtr GetCurrentSession(const std::string &device_target, uint32_t device_id) {
  auto iter = kSessionBackends.find(device_target);
  if (iter == kSessionBackends.end()) {
    auto session = session::SessionFactory::Get().Create(device_target);
    MS_EXCEPTION_IF_NULL(session);
    session->Init(device_id);
    kSessionBackends[device_target] = session;
    return session;
  } else {
    return iter->second;
  }
}

compile::MindRTBackendPtr GetMindRtBackend(const std::string &device_target, uint32_t device_id) {
  auto iter = kMindRtBackends.find(device_target);
  if (iter == kMindRtBackends.end()) {
    auto backend = std::make_shared<compile::MindRTBackend>("ms", device_target, device_id);
    MS_EXCEPTION_IF_NULL(backend);
    kMindRtBackends[device_target] = backend;
    return backend;
  } else {
    return iter->second;
  }
}
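
// Note: GetDstType below computes, for each signature dtype group with at least ARG_SIZE members, the
// target TypeId that arguments in that group should be implicitly cast to. It scans the grouped
// arguments, tracks the highest-priority tensor dtype via prim::type_map, and then lets JudgeMaxType
// arbitrate between tensor dtypes and Python scalars.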
void GetDstType(const py::tuple &py_args,
                const mindspore::HashMap<SignatureEnumDType, std::vector<size_t>> &type_indexes,
                mindspore::HashMap<SignatureEnumDType, TypeId> *dst_type) {
  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
    const auto &type = it->first;
    const auto &indexes = it->second;
    if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < ARG_SIZE) {
      continue;
    }
    size_t priority = 0;
    TypeId max_type = TypeId::kTypeUnknown;
    bool has_scalar_float32 = false;
    bool has_scalar_int64 = false;
    bool has_tensor_int8 = false;
    // Find the maximum priority of the same dtype.
    for (size_t index : indexes) {
      if (index >= py_args.size()) {
        MS_LOG(EXCEPTION) << "The index " << index << " exceeds the size of py_args " << py_args.size();
      }
      const auto &obj = py_args[index];
      if (py::isinstance<py::float_>(obj)) {
        has_scalar_float32 = true;
      }
      if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) {
        has_scalar_int64 = true;
      }
      if (py::isinstance<tensor::Tensor>(obj)) {
        auto arg = py::cast<tensor::TensorPtr>(obj);
        TypeId arg_type_id = arg->data_type();
        auto type_priority = prim::type_map.find(arg_type_id);
        if (type_priority == prim::type_map.end()) {
          continue;
        }
        if (arg_type_id == kNumberTypeInt8) {
          has_tensor_int8 = true;
        }
        if (type_priority->second > priority) {
          max_type = type_priority->first;
          priority = type_priority->second;
        }
      }
    }
    max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8);
    MS_EXCEPTION_IF_NULL(dst_type);
    (void)dst_type->emplace(std::make_pair(type, max_type));
  }
}

const std::string &TypeIdToMsTypeStr(const TypeId &type_id) {
  const auto &type_name = type_name_map.find(type_id);
  if (type_name == type_name_map.end()) {
    MS_LOG(EXCEPTION) << "For implicit type conversion, converting to the type " << TypeIdToType(type_id)
                      << " is not supported";
  }
  return type_name->second;
}

bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(dtypes);
  const auto &signature = prim->signatures();
  bool has_sig_dtype = false;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
                       [&has_sig_dtype](const Signature &sig) {
                         auto dtype = sig.dtype;
                         if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
                           has_sig_dtype = true;
                         }
                         return dtype;
                       });
  return has_sig_dtype;
}

void PynativeInfer(const PrimitivePyPtr &prim, OpExecInfo *const op_exec_info,
                   const abstract::AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
  prim->BeginRecordAddAttr();
  auto eval_ret = EvalOnePrim(prim, args_spec_list);
  MS_EXCEPTION_IF_NULL(eval_ret);
  AbstractBasePtr infer_res = eval_ret->abstract();
  MS_EXCEPTION_IF_NULL(infer_res);
  prim->EndRecordAddAttr();
  MS_EXCEPTION_IF_NULL(op_exec_info);
  op_exec_info->abstract = infer_res;
  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}
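
// Note: GetSingleOpGraphInfo below builds the cache key for a single-op graph. Roughly, the key
// concatenates the op name; per input its shape, dtype and padding type (plus the device-address type
// id and format when the tensor is already on device); the literal value of any constant input; every
// attribute value; and, when a constant input is present, the inferred output shape and type. Two
// launches share a compiled graph only if all of these match.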
void GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector<tensor::TensorPtr> &input_tensors,
                          const std::vector<int64_t> &tensors_mask, std::string *graph_info_key) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(graph_info_key);
  auto &graph_info = *graph_info_key;
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  std::ostringstream buf;
  buf << op_exec_info->op_name;
  bool has_const_input = false;
  for (size_t index = 0; index < input_tensors.size(); ++index) {
    MS_EXCEPTION_IF_NULL(input_tensors[index]);
    buf << input_tensors[index]->shape();
    buf << input_tensors[index]->data_type();
    buf << input_tensors[index]->padding_type();
    // Distinguish inputs that have the same shape but an inconsistent dtype or format.
    auto tensor_addr = input_tensors[index]->device_address();
    if (tensor_addr != nullptr) {
      auto p_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr);
      MS_EXCEPTION_IF_NULL(p_address);
      buf << p_address->type_id();
      buf << p_address->format();
    }
    // For constant input.
    if (tensors_mask[index] == kValueNodeTensorMask) {
      has_const_input = true;
      auto dtype = input_tensors[index]->Dtype();
      MS_EXCEPTION_IF_NULL(dtype);
      if (dtype->type_id() == kNumberTypeInt64) {
        buf << *reinterpret_cast<int64_t *>(input_tensors[index]->data_c());
      } else if (dtype->type_id() == kNumberTypeFloat32 || dtype->type_id() == kNumberTypeFloat16) {
        buf << *reinterpret_cast<float *>(input_tensors[index]->data_c());
      } else {
        MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64 or float32!";
      }
    }
    buf << "_";
  }
  // The value of an attribute affects operator selection.
  const auto &op_prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &attr_map = op_prim->attrs();
  (void)std::for_each(attr_map.begin(), attr_map.end(),
                      [&buf](const auto &element) { buf << element.second->ToString(); });
  // A constant input can affect the output: for operators like DropoutGenMask, the output depends on
  // the input values even when the input shapes are the same.
  if (has_const_input) {
    buf << "_";
    auto abstr = op_exec_info->abstract;
    MS_EXCEPTION_IF_NULL(abstr);
    auto build_shape = abstr->BuildShape();
    MS_EXCEPTION_IF_NULL(build_shape);
    buf << build_shape->ToString();
    auto build_type = abstr->BuildType();
    MS_EXCEPTION_IF_NULL(build_type);
    buf << build_type->type_id();
  }
  graph_info = buf.str();
}

py::list FilterTensorArgs(const py::args &args, bool has_sens = false) {
  size_t size = args.size();
  if (size == 0 && has_sens) {
    MS_LOG(EXCEPTION) << "The size of args is 0, but the sens flag is set to True";
  }
  py::list only_tensors;
  size_t forward_args_size = has_sens ? size - 1 : size;
  for (size_t i = 0; i < forward_args_size; ++i) {
    if (py::isinstance<tensor::Tensor>(args[i])) {
      only_tensors.append(args[i]);
    }
  }
  if (has_sens) {
    only_tensors.append(args[forward_args_size]);
  }
  return only_tensors;
}
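
// Note: RunOpConvertConstInputToAttr below handles operators registered in ConstInputToAttrInfoRegistry.
// When the current input index is registered, the Python input is converted to a Value and attached to
// the primitive as an attribute named after the corresponding entry in kAttrInputNames; the caller skips
// that input (it produces no runtime tensor) when this function returns true.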
bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                  const mindspore::HashSet<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &input_names_value = op_prim->GetAttr(kAttrInputNames);
  if (input_names_value == nullptr) {
    return false;
  }
  const auto &input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  if (input_index >= input_names_vec.size()) {
    MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
  }
  if (input_attrs.find(input_index) != input_attrs.end()) {
    const auto &value = PyObjToValue(input_object);
    auto input_name = input_names_vec[input_index];
    op_prim->AddAttr(input_name, value);
    return true;
  }
  return false;
}

void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
                              std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    input_tensors->emplace_back(tensor);
  }
  op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(tuple_inputs.size())}));
}

void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  const auto &input_value = PyObjToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }
  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                                  std::vector<tensor::TensorPtr> *input_tensors, int64_t *const tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }
  auto tuple_inputs = py::cast<py::tuple>(input_object);
  if (tuple_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }
  if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
    PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
  } else {
    ConvertValueTupleToTensor(input_object, input_tensors);
    *tensor_mask = kValueNodeTensorMask;
  }
}

void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                             std::vector<tensor::TensorPtr> *input_tensors, int64_t *const tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (py::isinstance<tensor::Tensor>(input_object)) {
    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  } else if (py::isinstance<py::float_>(input_object)) {
    double input_value = py::cast<py::float_>(input_object);
    tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::int_>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::array>(input_object)) {
    tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
  } else if (py::isinstance<py::list>(input_object)) {
    auto list_inputs = py::cast<py::list>(input_object);
    py::tuple tuple_inputs(list_inputs.size());
    for (size_t i = 0; i < tuple_inputs.size(); ++i) {
      tuple_inputs[i] = list_inputs[i];
    }
    ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    return;
  } else {
    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}
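
// Note: ConstructInputTensor below flattens the Python-level op inputs into the backend tensor list.
// Each produced tensor carries a mask, as the inline comment in the loop records: 0 for common tensor
// data, 1 for weight parameters, and kValueNodeTensorMask (2) for values materialized from Python
// scalars or value tuples. Inputs consumed by const-input-to-attr conversion produce no tensor at all.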
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  // Check whether attr conversion is needed.
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = false;
  if (op_run_info->op_name == prim::kPrimCustom->name()) {
    // A custom op needs to set the register dynamically.
    mindspore::HashSet<size_t> attr_indexes;
    opt::GetCustomOpAttrIndex(op_prim, &attr_indexes);
    if (!attr_indexes.empty()) {
      reg_exist = true;
      reg.SetConstInputToAttr(attr_indexes);
    }
  } else {
    reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  }
  if (op_run_info->is_dynamic_shape &&
      dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
    MS_LOG(DEBUG) << "The current node has a dynamic shape: " << op_run_info->op_name;
    reg_exist = false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target != kCPUDevice && op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
    reg_exist = false;
  }
  // The GatherD op needs const input to attr conversion only on the GPU device.
  if (device_target != kGPUDevice && op_run_info->op_name == prim::kPrimGatherD->name()) {
    reg_exist = false;
  }
  // Get input tensors.
  op_prim->BeginRecordAddAttr();
  size_t input_num = op_run_info->op_inputs.size();
  if (input_num != op_run_info->inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "The op input size is " << input_num << ", but the size of the input mask is "
                      << op_run_info->inputs_mask.size();
  }
  for (size_t index = 0; index < input_num; ++index) {
    // Convert const input to attr.
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // Convert const and tuple input to tensor.
    int64_t tensor_mask = op_run_info->inputs_mask[index];
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // Mark tensors: common tensor data: 0, weight param: 1, value node (float_, int_): 2.
    op_run_info->inputs_mask[index] = tensor_mask;
    std::vector<int64_t> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
  op_prim->EndRecordAddAttr();
}

void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &op_name = op_run_info->op_name;
  auto attrs = op_prim->attrs();
  for (auto attr : attrs) {
    bool converted = CheckAndConvertUtils::ConvertAttrValueToString(op_name, attr.first, &attr.second);
    if (converted) {
      op_prim->set_attr(attr.first, attr.second);
    }
    bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(op_name, attr.first, &attr.second);
    if (converted_ir_attr) {
      op_prim->set_attr(attr.first, attr.second);
    }
  }
}

size_t GetTupleSize(const py::tuple &args) {
  size_t count = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      count += GetTupleSize(args[i]);
    } else {
      count += 1;
    }
  }
  return count;
}

void ConvertTupleArg(py::tuple *res, size_t *const index, const py::tuple &arg) {
  MS_EXCEPTION_IF_NULL(res);
  MS_EXCEPTION_IF_NULL(index);
  auto res_size = res->size();
  for (size_t i = 0; i < arg.size(); i++) {
    if (py::isinstance<py::tuple>(arg[i])) {
      ConvertTupleArg(res, index, arg[i]);
    } else {
      if (*index >= res_size) {
        MS_LOG(EXCEPTION) << "Convert tuple error, index is greater than tuple size, index " << (*index)
                          << ", tuple size " << res_size;
      }
      (*res)[(*index)++] = arg[i];
    }
  }
}

py::tuple ConvertArgs(const py::tuple &args) {
  size_t tuple_size = GetTupleSize(args);
  py::tuple res(tuple_size);
  size_t index = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      ConvertTupleArg(&res, &index, args[i]);
    } else {
      if (index >= tuple_size) {
        MS_LOG(EXCEPTION) << "Convert error, index is greater than tuple size, index " << index << ", tuple size "
                          << tuple_size;
      }
      res[index++] = args[i];
    }
  }
  return res;
}

void ResetTopCellInfo(const TopCellInfoPtr &top_cell, const py::args &args) {
  MS_EXCEPTION_IF_NULL(top_cell);
  top_cell->set_op_num(0);
  top_cell->all_op_info().clear();
  top_cell->set_forward_already_run(true);
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id += GetId(args[i]) + "_";
  }
  top_cell->set_input_args_id(input_args_id);
}

void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
                const FuncGraphPtr &grad_graph) {
  MS_EXCEPTION_IF_NULL(grad_graph);
  MS_EXCEPTION_IF_NULL(added_make_tuple);
  size_t index = 0;
  for (size_t i = 1; i < added_make_tuple->size(); ++i) {
    const auto &input_i = added_make_tuple->input(i);
    MS_EXCEPTION_IF_NULL(input_i);
    auto cnode = input_i->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(DEBUG) << "Replace new output tensors for cnode: " << cnode->DebugString();
    auto output_vnode = cnode->forward().first;
    MS_EXCEPTION_IF_NULL(output_vnode);
    grad_graph->AddValueNode(output_vnode);
    MS_LOG(DEBUG) << "Original output value node: " << output_vnode << " info: " << output_vnode->ToString();
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    if (index + output_num > total_output_tensors.size()) {
      MS_LOG(EXCEPTION) << "The size of total_output_tensors: " << total_output_tensors.size()
                        << ", but the current index: " << index << ", output num: " << output_num;
    }
    // Get new tensors.
    std::vector<ValuePtr> new_values;
    for (size_t j = index; j < index + output_num; ++j) {
      new_values.push_back(total_output_tensors[j]);
    }
    index = index + output_num;
    // Replace with the new tensors.
    if (output_num == 1) {
      output_vnode->set_value(new_values[0]);
    } else if (output_num > 1) {
      output_vnode->set_value(std::make_shared<ValueTuple>(new_values));
    } else {
      MS_LOG(EXCEPTION) << "The output value of forward cnode is empty, forward cnode info: " << cnode->ToString();
    }
    MS_LOG(DEBUG) << "New output value node: " << output_vnode << " info: " << output_vnode->ToString();
  }
  // Save op info with new tensors for the currently running ms_function func graph.
  if (index != total_output_tensors.size()) {
    MS_LOG(EXCEPTION) << "The index: " << index
                      << " should be equal to the size of total_output_tensors: " << total_output_tensors.size();
  }
}

void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const OpExecInfoPtr &op_exec_info,
                                  const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
                                  const FuncGraphPtr &grad_graph) {
  MS_EXCEPTION_IF_NULL(top_cell);
  MS_EXCEPTION_IF_NULL(grad_graph);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  // Get added forward nodes.
  auto merge_node = ms_func_graph->output();
  MS_EXCEPTION_IF_NULL(merge_node);
  auto merge_make_tuple = merge_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(merge_make_tuple);
  constexpr size_t merge_output_size = 3;
  if (merge_make_tuple->size() != merge_output_size) {
    MS_LOG(EXCEPTION) << "The input size of merge make tuple node should be 3, but it is: "
                      << merge_make_tuple->size();
  }
  constexpr size_t added_output_index = 2;
  const auto &added_forward_node = merge_make_tuple->input(added_output_index);
  MS_EXCEPTION_IF_NULL(added_forward_node);
  if (added_forward_node->isa<ValueNode>()) {
    MS_LOG(DEBUG) << "The added forward output node is a value node: " << added_forward_node->DebugString();
    std::vector<tensor::TensorPtr> total_output_tensors;
    TensorValueToTensor(added_out, &total_output_tensors);
    top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
    return;
  }
  // Replace the new output tensors for the forward nodes; this also takes effect in the grad graph,
  // which shares the same value nodes.
  auto added_make_tuple = added_forward_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(added_make_tuple);
  MS_LOG(DEBUG) << "The added forward make tuple node info: " << added_make_tuple->DebugString();
  std::vector<tensor::TensorPtr> total_output_tensors;
  TensorValueToTensor(added_out, &total_output_tensors);
  RunReplace(added_make_tuple, total_output_tensors, grad_graph);
  top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
}

void SaveOpInfo(const TopCellInfoPtr &top_cell, const std::string &op_info,
                const std::vector<tensor::TensorPtr> &op_out_tensors) {
  MS_EXCEPTION_IF_NULL(top_cell);
  auto &op_info_with_tensor_id = top_cell->op_info_with_tensor_id();
  if (op_info_with_tensor_id.find(op_info) != op_info_with_tensor_id.end()) {
    MS_LOG(EXCEPTION) << "Top cell: " << top_cell.get() << " records op info with tensor id, but the op info "
                      << op_info << " already exists in the op_info_with_tensor_id map";
  }
  // Record the relationship between the forward op and its output tensor ids.
  std::for_each(op_out_tensors.begin(), op_out_tensors.end(),
                [&op_info_with_tensor_id, &op_info](const tensor::TensorPtr &tensor) {
                  op_info_with_tensor_id[op_info].emplace_back(tensor->id());
                });
}
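
// Note: UpdateTensorInfo below propagates a newly produced tensor into previously recorded tensors,
// apparently so that cached tensor objects observe the new values on a re-run. On non-CPU targets the
// old tensor simply adopts the new device address; on the CPU path the data is copied into the old
// device address (after flushing lazy backend tasks), or synced through host memory when the old
// tensor has no device address yet.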
void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<tensor::TensorPtr> &pre_tensors) {
  MS_EXCEPTION_IF_NULL(new_tensor);
  if (pre_tensors.empty()) {
    MS_LOG(EXCEPTION) << "The list of pre tensors is empty.";
  }
  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  for (auto &pre_tensor : pre_tensors) {
    MS_EXCEPTION_IF_NULL(pre_tensor);
    MS_LOG(DEBUG) << "Replace old tensor " << pre_tensor.get() << " id " << pre_tensor->id()
                  << " device_address: " << pre_tensor->device_address() << " shape and type "
                  << pre_tensor->GetShapeAndDataTypeInfo() << " with new tensor " << new_tensor.get() << " id "
                  << new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype "
                  << new_tensor->GetShapeAndDataTypeInfo();
    pre_tensor->set_shape(new_tensor->shape());
    pre_tensor->set_data_type(new_tensor->data_type());
    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_target != kCPUDevice && device_address->DeviceType() != device::DeviceAddressType::kCPU) {
      pre_tensor->set_device_address(new_tensor->device_address());
      continue;
    }
    for (auto &item : kMindRtBackends) {
      MS_EXCEPTION_IF_NULL(item.second);
      item.second->SyncLazyTasks();
    }
    // Replace the data in the device address when running on the CPU device.
    if (pre_tensor->device_address() != nullptr) {
      auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(pre_tensor->device_address());
      MS_EXCEPTION_IF_NULL(old_device_address);
      auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
      MS_EXCEPTION_IF_NULL(new_device_address);
      auto old_ptr = old_device_address->GetMutablePtr();
      MS_EXCEPTION_IF_NULL(old_ptr);
      auto new_ptr = new_device_address->GetPtr();
      MS_EXCEPTION_IF_NULL(new_ptr);
      MS_EXCEPTION_IF_CHECK_FAIL(old_device_address->GetSize() == new_device_address->GetSize(), "Size not equal");
      if (old_device_address->GetSize() < SECUREC_MEM_MAX_LEN) {
        auto ret_code = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
        MS_EXCEPTION_IF_CHECK_FAIL(ret_code == EOK, "Memory copy failed, ret code: " + std::to_string(ret_code));
      } else {
        auto ret_code = std::memcpy(old_ptr, new_ptr, old_device_address->GetSize());
        MS_EXCEPTION_IF_CHECK_FAIL(ret_code == old_ptr, "Memory copy failed");
      }
    } else {
      pre_tensor->set_device_address(device_address);
      pre_tensor->data_sync();
      pre_tensor->set_device_address(nullptr);
      pre_tensor->set_sync_status(kNeedSyncHostToDevice);
    }
  }
}

void CheckPyNativeContext() {
  const auto &parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &parallel_mode = parallel_context->parallel_mode();
  if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL &&
      ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    MS_LOG(EXCEPTION) << "PyNative only supports STAND_ALONE and DATA_PARALLEL, but got: " << parallel_mode;
  }
}

py::object GetDstType(const TypeId &type_id) {
  ValuePtr value = nullptr;
  if (type_id == kNumberTypeFloat16) {
    value = std::make_shared<Float>(16);
  } else if (type_id == kNumberTypeFloat32) {
    value = std::make_shared<Float>(32);
  } else if (type_id == kNumberTypeFloat64) {
    value = std::make_shared<Float>(64);
  } else if (type_id == kNumberTypeBool) {
    value = std::make_shared<Bool>();
  } else if (type_id == kNumberTypeInt8) {
    value = std::make_shared<Int>(8);
  } else if (type_id == kNumberTypeUInt8) {
    value = std::make_shared<UInt>(8);
  } else if (type_id == kNumberTypeInt16) {
    value = std::make_shared<Int>(16);
  } else if (type_id == kNumberTypeInt32) {
    value = std::make_shared<Int>(32);
  } else if (type_id == kNumberTypeInt64) {
    value = std::make_shared<Int>(64);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported dst type";
  }
  MS_EXCEPTION_IF_NULL(value);
  return py::cast(value);
}

bool IsPyObjTypeInvalid(const py::object &obj) {
  return !py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj);
}
}  // namespace
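
// Note: RealRunOp is the C++ entry point invoked from Python for each primitive call in PyNative mode.
// It validates the parallel context, builds an OpExecInfo from the (name, primitive, inputs) args, and
// dispatches through PynativeExecutorTry so that any failure clears executor state before propagating.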
py::object RealRunOp(const py::args &args) {
  CheckPyNativeContext();
  const auto &executor = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  OpExecInfoPtr op_exec_info = executor->forward_executor()->GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  py::object ret = py::none();
  PynativeExecutorTry(executor->forward_executor()->RunOpS, &ret, op_exec_info);
  return ret;
}

GradExecutorPtr ForwardExecutor::grad() const {
  auto grad_executor = grad_executor_.lock();
  MS_EXCEPTION_IF_NULL(grad_executor);
  return grad_executor;
}

bool TopCellInfo::IsSubCell(const std::string &cell_id) const {
  if (sub_cell_list_.empty()) {
    MS_LOG(DEBUG) << "The sub cell list is empty, there is no sub cell";
    return false;
  }
  if (sub_cell_list_.find(cell_id) != sub_cell_list_.end()) {
    return true;
  }
  return false;
}

void TopCellInfo::ClearDeviceMemory() {
  MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target == kCPUDevice) {
    MS_LOG(DEBUG) << "No need to clear the device address when running on the CPU device.";
    return;
  }
  k_pynative_cell_ptr_ = nullptr;
  // Get all tensor objects in the value nodes of the running graph.
  std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
  MS_EXCEPTION_IF_NULL(resource_);
  const auto &bprop_graph = resource_->func_graph();
  MS_EXCEPTION_IF_NULL(bprop_graph);
  const auto &value_node_list = bprop_graph->value_nodes();
  for (const auto &elem : value_node_list) {
    auto &node = elem.first;
    MS_EXCEPTION_IF_NULL(node);
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
  }
  for (const auto &tensor : tensors_in_bprop_graph) {
    MS_EXCEPTION_IF_NULL(tensor);
    MS_LOG(DEBUG) << "Clear device address for tensor: " << tensor->ToString();
    tensor->set_device_address(nullptr);
  }
}

void TopCellInfo::Clear() {
  MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
  op_num_ = 0;
  is_dynamic_ = false;
  vm_compiled_ = false;
  ms_function_flag_ = false;
  is_init_kpynative_ = false;
  need_compile_graph_ = false;
  forward_already_run_ = false;
  input_args_id_.clear();
  all_op_info_.clear();
  resource_ = nullptr;
  df_builder_ = nullptr;
  fg_ = nullptr;
  k_pynative_cell_ptr_ = nullptr;
  graph_info_map_.clear();
  sub_cell_list_.clear();
  op_info_with_tensor_id_.clear();
  tensor_id_with_tensor_object_.clear();
  op_info_with_ms_func_forward_tensors_.clear();
}

void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(ret);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(DEBUG) << "RunOp name: " << op_exec_info->op_name;
  if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
    RunMixedPrecisionCastOp(op_exec_info, ret);
    return;
  }
  // 1. Set cast for inputs.
  SetCastForInputs(op_exec_info);
  // 2. Construct the graph; in this first step the abstract will be updated by the node.
  auto cnode = ConstructForwardGraph(op_exec_info);
  // 3. Get the inputs' abstracts.
  abstract::AbstractBasePtrList args_spec_list;
  GetInputsArgsSpec(op_exec_info, &args_spec_list);
  // 4. Get the output abstract.
  bool prim_cache_hit = false;
  GetOpOutputAbstract(op_exec_info, args_spec_list, &prim_cache_hit);
  // 5. Get the output.
  GetOpOutput(op_exec_info, args_spec_list, cnode, prim_cache_hit, ret);
}
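
// Note: GenerateOpExecInfo below expects exactly PY_ARGS_NUM (three) positional args from the Python
// side: the op name (PY_NAME), the PrimitivePy adapter (PY_PRIM), and the input tuple (PY_INPUTS).
// The attached PrimitivePy is created lazily on first use and cached on the adapter.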
  916. OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
  917. if (args.size() != PY_ARGS_NUM) {
  918. MS_LOG(EXCEPTION) << "Three args are needed by RunOp";
  919. }
  920. const auto &op_exec_info = std::make_shared<OpExecInfo>();
  921. const auto &op_name = py::cast<std::string>(args[PY_NAME]);
  922. op_exec_info->op_name = op_name;
  923. const auto &adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]);
  924. MS_EXCEPTION_IF_NULL(adapter);
  925. auto prim = adapter->attached_primitive();
  926. if (prim == nullptr) {
  927. prim = std::make_shared<PrimitivePy>(args[PY_PRIM], adapter);
  928. adapter->set_attached_primitive(prim);
  929. }
  930. if (!prim->HasPyObj()) {
  931. MS_LOG(EXCEPTION) << "Pyobj is empty";
  932. }
  933. op_exec_info->py_primitive = prim;
  934. op_exec_info->op_inputs = args[PY_INPUTS];
  935. op_exec_info->lazy_build = lazy_build_;
  936. return op_exec_info;
  937. }

void ForwardExecutor::SetCastForInputs(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  // No need to cast the Cast op itself
  if (op_exec_info->op_name == prim::kPrimCast->name()) {
    return;
  }
  // Mixed precision conversion for tensors which have a cast dtype set
  SetTensorMixPrecisionCast(op_exec_info);
  // Implicit transform
  SetImplicitCast(op_exec_info);
}

void ForwardExecutor::RunMixedPrecisionCastOp(const OpExecInfoPtr &op_exec_info, py::object *ret) {
  MS_EXCEPTION_IF_NULL(ret);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  py::tuple res = RunOpWithInitBackendPolicy(op_exec_info);
  if (res.size() == 1) {
    *ret = res[0];
    return;
  }
  *ret = std::move(res);
}

void ForwardExecutor::SetNonCostantValueAbs(const AbstractBasePtr &abs, size_t i, const std::string &id) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    abs->set_value(kAnyValue);
  } else if (abs->isa<abstract::AbstractTuple>() || abs->isa<abstract::AbstractList>()) {
    const auto &abs_seq = abs->cast<abstract::AbstractSequeuePtr>();
    MS_EXCEPTION_IF_NULL(abs_seq);
    for (auto &item : abs_seq->elements()) {
      MS_EXCEPTION_IF_NULL(item);
      if (item->isa<abstract::AbstractTensor>()) {
        item->set_value(kAnyValue);
      }
    }
  }
  MS_LOG(DEBUG) << "Set " << i << "th abs " << abs->ToString();
  node_abs_map_[id] = abs;
}
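
// NOTE: node_abs_map_ caches each input object's abstract by object id so repeated ops
// on the same object skip re-derivation. Inputs of const prims and inputs listed in a
// prim's const_input_indexes keep their concrete value for inference instead of being
// broadened to kAnyValue by SetNonCostantValueAbs.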
void ForwardExecutor::GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info,
                                        abstract::AbstractBasePtrList *args_spec_list) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(args_spec_list);
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    abstract::AbstractBasePtr abs = nullptr;
    const auto &obj = op_exec_info->op_inputs[i];
    const auto &id = GetId(obj);
    MS_LOG(DEBUG) << "Set input abs " << id;
    auto it = node_abs_map_.find(id);
    if (it != node_abs_map_.end()) {
      abs = it->second;
    }
    const auto const_input_index = prim->get_const_input_indexes();
    bool have_const_input = !const_input_index.empty();
    bool is_const_prim = prim->is_const_prim();
    MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                  << prim->is_const_prim();
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    if (abs == nullptr || is_const_prim || is_const_input) {
      abs = PyObjToValue(obj)->ToAbstract();
      if (!is_const_prim && !is_const_input) {
        SetNonCostantValueAbs(abs, i, id);
      }
    }
    args_spec_list->emplace_back(abs);
  }
}

CNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto prim = op_exec_info->py_primitive;
  std::vector<AnfNodePtr> inputs;
  std::vector<int64_t> op_masks;
  inputs.emplace_back(NewValueNode(prim));
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    const auto &obj = op_exec_info->op_inputs[i];
    bool op_mask = false;
    tensor::MetaTensorPtr meta_tensor = nullptr;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor) {
        op_mask = meta_tensor->is_parameter();
      }
    }
    MS_LOG(DEBUG) << "Args i " << i << ", op mask " << op_mask;
    op_masks.emplace_back(static_cast<int64_t>(op_mask));
    // Construct grad graph
    if (grad()->need_construct_graph()) {
      const auto &id = GetId(obj);
      AnfNodePtr input_node = nullptr;
      input_node = grad()->GetInput(obj, op_mask);
      // Update abstract
      if (input_node != nullptr) {
        if (input_node->abstract() != nullptr) {
          abstract::AbstractBasePtr abs = input_node->abstract();
          node_abs_map_[id] = abs;
        }
        inputs.emplace_back(input_node);
      }
    }
  }
  op_exec_info->inputs_mask = std::move(op_masks);
  CNodePtr cnode = nullptr;
  if (grad()->need_construct_graph()) {
    cnode = grad()->curr_g()->NewCNodeInOrder(inputs);
    MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << ", new cnode is " << cnode->DebugString();
  }
  return cnode;
}
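
// NOTE: The output-abstract cache prim_abs_list_ is keyed by the primitive's name, hash
// and attrs, then probed with the exact input abstract list; a hit restores both the
// cached output abstract and the attrs that a previous evaluation added to the prim.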
void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
                                          const abstract::AbstractBasePtrList &args_spec_list, bool *prim_cache_hit) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(prim_cache_hit);
  auto op_name = op_exec_info->op_name;
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
  auto temp = prim_abs_list_.find(key);
  if (temp != prim_abs_list_.end()) {
    MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
    auto iter = temp->second.find(args_spec_list);
    if (iter != temp->second.end()) {
      MS_LOG(DEBUG) << "Match prim ok " << op_name;
      op_exec_info->abstract = iter->second.abs;
      prim->set_evaluate_added_attrs(iter->second.attrs);
      *prim_cache_hit = true;
    }
  }
  if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
    // Use python infer method
    if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
      PynativeInfer(prim, op_exec_info.get(), args_spec_list);
    }
  }
  // Get output dynamic shape info
  auto abstract = op_exec_info->abstract;
  MS_EXCEPTION_IF_NULL(abstract);
  auto shape = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape);
  if (shape->IsDynamic()) {
    op_exec_info->is_dynamic_shape = true;
    // Dynamic shape operator in the current top cell, disable backend cache
    grad()->EnableOpGraphCache(false);
  }
}

void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
                                  const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
                                  bool prim_cache_hit, py::object *ret) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  // Infer output value by constant folding
  MS_EXCEPTION_IF_NULL(ret);
  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
  if (!output["value"].is_none()) {
    *ret = output["value"];
    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
    return;
  }
  if (prim->is_const_prim()) {
    *ret = py::cast("");
    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
    return;
  }
  // Add output abstract info into the cache; a const value needs to be inferred every step
  if (grad()->enable_op_cache() && !prim_cache_hit && !op_exec_info->is_dynamic_shape) {
    AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
    auto &out = prim_abs_list_[key];
    out[args_spec_list].abs = op_exec_info->abstract;
    out[args_spec_list].attrs = prim->evaluate_added_attrs();
  }
  // Run op with the selected backend
  auto result = RunOpWithInitBackendPolicy(op_exec_info);
  py::object out_real = result;
  if (result.size() == 1 && op_exec_info->abstract != nullptr &&
      !op_exec_info->abstract->isa<abstract::AbstractSequeue>()) {
    out_real = result[0];
  }
  // Get output value
  ValuePtr out_real_value = nullptr;
  if (grad()->grad_flag()) {
    out_real_value = PyObjToValue(out_real);
  }
  // Save cnode info and build grad graph
  if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) {
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &obj_id = GetId(out_real);
    cnode->set_abstract(op_exec_info->abstract);
    node_abs_map_[obj_id] = op_exec_info->abstract;
    grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
    grad()->DoOpGrad(op_exec_info, cnode, out_real_value);
  } else {
    node_abs_map_.clear();
  }
  // Record op info to judge whether the construct of the cell has been changed
  grad()->RecordGradOpInfo(op_exec_info, out_real_value);
  grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
  *ret = out_real;
}
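
// NOTE: DoAutoCast wraps the target argument in a Cast OpExecInfo and sends it back
// through RunOpInner, so an implicit cast executes through exactly the same pipeline
// as a user-issued Cast op.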
py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
                                       size_t index) {
  static py::object cast_prim = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  const auto &op_exec_info = std::make_shared<OpExecInfo>();
  op_exec_info->op_name = prim::kPrimCast->name();
  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
  MS_EXCEPTION_IF_NULL(adapter);
  auto prim = adapter->attached_primitive();
  if (prim == nullptr) {
    prim = std::make_shared<PrimitivePy>(cast_prim, adapter);
    adapter->set_attached_primitive(prim);
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->is_mixed_precision_cast = true;
  op_exec_info->next_op_name = op_name;
  op_exec_info->next_input_index = index;
  py::object dst_type = GetDstType(type_id);
  py::tuple inputs(ARG_SIZE);
  inputs[0] = arg;
  inputs[1] = dst_type;
  op_exec_info->op_inputs = inputs;
  op_exec_info->lazy_build = lazy_build_;
  py::object ret = py::none();
  RunOpInner(&ret, op_exec_info);
  return ret;
}

py::object ForwardExecutor::DoAutoCastTuple(const py::tuple &tuple, const TypeId &type_id, const std::string &op_name,
                                            size_t index) {
  auto tuple_size = tuple.size();
  py::tuple result(tuple_size);
  for (size_t i = 0; i < tuple_size; i++) {
    if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
      result[i] = DoAutoCastTuple(tuple[i], type_id, op_name, index);
    } else {
      result[i] = DoAutoCast(tuple[i], type_id, op_name, index);
    }
  }
  return std::move(result);
}

py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name,
                                                    size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  const auto &tensor = py::cast<tensor::TensorPtr>(obj);
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &cast_type = tensor->cast_dtype();
  if (cast_type != nullptr) {
    auto source_element = tensor->Dtype();
    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
      MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
      *is_cast = true;
      return DoAutoCast(obj, cast_type->type_id(), op_name, index);
    }
  }
  return obj;
}

py::object ForwardExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple,
                                                         const std::string &op_name, size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  auto tuple_size = tuple.size();
  py::tuple result(tuple_size);
  for (size_t i = 0; i < tuple_size; i++) {
    if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
      MS_LOG(DEBUG) << "Call cast for item " << i;
      result[i] = DoParamMixPrecisionCast(is_cast, tuple[i], op_name, index);
    } else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
      result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i], op_name, index);
    } else {
      result[i] = tuple[i];
    }
  }
  return std::move(result);
}
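
// NOTE: DoSignatureCast implements implicit type promotion driven by the primitive's
// dtype signatures: inputs that share a signature dtype group are cast to the group's
// target type resolved in dst_type. As an illustrative example, adding a float16 tensor
// to a float32 tensor would cast the float16 side to float32; the actual target type is
// whatever GetDstType resolves for the real inputs.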
void ForwardExecutor::DoSignatureCast(const PrimitivePyPtr &prim,
                                      const mindspore::HashMap<SignatureEnumDType, TypeId> &dst_type,
                                      const std::vector<SignatureEnumDType> &dtypes,
                                      const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &signature = prim->signatures();
  auto &input_args = op_exec_info->op_inputs;
  size_t input_args_size = input_args.size();
  for (size_t i = 0; i < input_args_size; ++i) {
    // No need to implicit cast if no dtype.
    if (dtypes.empty() || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
      continue;
    }
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end() || it->second == kTypeUnknown) {
      continue;
    }
    MS_LOG(DEBUG) << "Check inputs " << i;
    const auto &obj = input_args[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      if (i >= signature.size()) {
        MS_EXCEPTION(ValueError) << "Signature size is not equal to index, signature size " << signature.size()
                                 << ", index " << i;
      }
      sig = signature[i].rw;
    }
    TypeId arg_type_id = kTypeUnknown;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      const auto &arg = py::cast<tensor::MetaTensorPtr>(obj);
      arg_type_id = arg->data_type();
    }
    // Implicit cast
    bool is_same_type = false;
    if (arg_type_id != kTypeUnknown) {
      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
    }
    if (sig == SignatureEnumRW::kRWWrite && arg_type_id != kTypeUnknown && !is_same_type) {
      prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
                                             TypeIdToMsTypeStr(it->second));
    }
    if (is_same_type) {
      continue;
    }
    if (IsPyObjTypeInvalid(obj)) {
      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i << "th input " << signature[i].name
                              << " does not support implicit conversion. "
                              << "Its type is " << py::cast<std::string>(obj.attr("__class__").attr("__name__"))
                              << ", and the value is " << py::cast<py::str>(obj)
                              << ". Only Tensor or Scalar is supported.";
    }
    py::object cast_output = DoAutoCast(input_args[i], it->second, op_exec_info->op_name, i);
    input_args[i] = cast_output;
  }
}

void ForwardExecutor::SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &signature = prim->signatures();
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    const auto &obj = op_exec_info->op_inputs[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      if (i >= signature.size()) {
        MS_EXCEPTION(ValueError) << "Signature size is not equal to index, signature size " << signature.size()
                                 << ", index " << i;
      }
      sig = signature[i].rw;
    }
    MS_LOG(DEBUG) << "Check mix precision " << op_exec_info->op_name << " input " << i;
    // Mixed precision for non-parameter inputs
    bool is_cast = false;
    py::object cast_output;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor && meta_tensor->is_parameter()) {
        // If the parameter is written (not kRWRead), no need to cast
        if (sig != SignatureEnumRW::kRWRead) {
          continue;
        }
      }
      cast_output = DoParamMixPrecisionCast(&is_cast, obj, prim->name(), i);
    } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
      // Mixed precision for tuple inputs
      cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj, prim->name(), i);
    }
    if (is_cast) {
      op_exec_info->op_inputs[i] = cast_output;
    }
  }
}
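
// NOTE: SetImplicitCast caches the parsed signature (dtype groups and type indexes) per
// primitive name in implicit_cast_map_, so the signature walk is done once on the first
// call and later calls only recompute the destination types for the new inputs.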
void ForwardExecutor::SetImplicitCast(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &it = implicit_cast_map_.find(prim->name());
  if (it == implicit_cast_map_.end()) {
    MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name << " first";
    const auto &signature = prim->signatures();
    auto sig_size = signature.size();
    // Ignore monad signature
    for (const auto &sig : signature) {
      if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
        --sig_size;
      }
    }
    auto size = op_exec_info->op_inputs.size();
    if (sig_size > 0 && sig_size != size) {
      MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size
                               << " does not match the required signature size " << sig_size;
    }
    std::vector<SignatureEnumDType> dtypes;
    mindspore::HashMap<SignatureEnumDType, std::vector<size_t>> type_indexes;
    bool has_dtype_sig = GetSignatureType(op_exec_info->py_primitive, &dtypes);
    if (has_dtype_sig) {
      mindspore::HashMap<SignatureEnumDType, TypeId> dst_type;
      GetTypeIndex(dtypes, &type_indexes);
      GetDstType(op_exec_info->op_inputs, type_indexes, &dst_type);
      DoSignatureCast(op_exec_info->py_primitive, dst_type, dtypes, op_exec_info);
    }
    PrimSignature sig_value{has_dtype_sig, dtypes, type_indexes};
    implicit_cast_map_[prim->name()] = sig_value;
  } else {
    if (!it->second.has_dtype_sig) {
      MS_LOG(DEBUG) << op_exec_info->op_name << " has no dtype sig";
      return;
    }
    MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name << " with cache";
    mindspore::HashMap<SignatureEnumDType, TypeId> dst_type;
    GetDstType(op_exec_info->op_inputs, it->second.type_indexes, &dst_type);
    DoSignatureCast(op_exec_info->py_primitive, dst_type, it->second.dtypes, op_exec_info);
  }
}
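
// NOTE: GetInput resolves a Python object to an ANF node along three paths: a cell
// parameter (op_mask set) becomes a free parameter of df_builder, an object already
// recorded in the current graph reuses its node, and everything else becomes a value
// node; tuples or lists of tensors are rebuilt as a MakeTuple of recursively resolved
// inputs.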
AnfNodePtr GradExecutor::GetInput(const py::object &obj, bool op_mask) {
  AnfNodePtr node = nullptr;
  const auto &obj_id = GetId(obj);
  if (op_mask) {
    MS_LOG(DEBUG) << "Cell parameters(weights)";
    // Get the parameter name from the parameter object
    auto name_attr = parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    const auto &param_name = py::cast<std::string>(name_attr);
    auto df_builder = top_cell()->df_builder();
    MS_EXCEPTION_IF_NULL(df_builder);
    auto graph_info = top_cell()->graph_info_map().at(df_builder);
    MS_EXCEPTION_IF_NULL(graph_info);
    if (graph_info->params.find(obj_id) == graph_info->params.end()) {
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      free_param->debug_info()->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(obj);
      free_param->set_default_param(value);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      SetParamNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetParamNodeMapInGraphInfoMap(curr_g(), obj_id, free_param);
      SetNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetNodeMapInGraphInfoMap(curr_g(), obj_id, free_param);
      return free_param;
    }
    node = graph_info->params.at(obj_id);
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Get input param node " << node->ToString() << ", obj id " << obj_id;
    return node;
  }
  auto curr_graph_info = top_cell()->graph_info_map().at(curr_g());
  MS_EXCEPTION_IF_NULL(curr_graph_info);
  if (curr_graph_info->node_map.find(obj_id) != curr_graph_info->node_map.end()) {
    // op(x, y)
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj, obj_id);
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    auto tuple = obj.cast<py::tuple>();
    // cell((1, 2)): mixing (scalar, tensor) is not supported
    if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
      return MakeValueNode(obj, obj_id);
    }
    std::vector<AnfNodePtr> args;
    args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple_size = tuple.size();
    for (size_t i = 0; i < tuple_size; i++) {
      args.emplace_back(GetInput(tuple[i], false));
    }
    auto cnode = curr_g()->NewCNode(args);
    SetNodeMapInGraphInfoMap(curr_g(), GetId(obj), cnode);
    node = cnode;
  } else {
    node = MakeValueNode(obj, obj_id);
  }
  node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
                  : MS_LOG(DEBUG) << "Get input node " << node->ToString() << ", id " << obj_id;
  return node;
}

AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
  auto graph_info = top_cell()->graph_info_map().at(curr_g());
  MS_EXCEPTION_IF_NULL(graph_info);
  if (graph_info->node_map.find(obj_id) == graph_info->node_map.end()) {
    // A tuple is returned in this case: x = op1, y = op2, return (x, y),
    // or a constant is returned in this case
    auto make_tuple = CreateMakeTupleNode(obj, obj_id);
    if (make_tuple == nullptr) {
      MS_LOG(DEBUG) << "Create value node for obj id: " << obj_id;
      return MakeValueNode(obj, obj_id);
    }
    return make_tuple;
  }
  // Single output CNode
  const auto &out = graph_info->node_map.at(obj_id);
  if (out.second.size() == 1 && out.second[0] == -1) {
    return out.first;
  }
  // Params node
  if (graph_info->params.find(obj_id) != graph_info->params.end()) {
    auto para_node = out.first;
    for (auto &v : out.second) {
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, NewValueNode(v)};
      para_node = curr_g()->NewCNode(tuple_get_item_inputs);
    }
    return para_node;
  }
  // Create tuple get item node for multiple output CNode
  return CreateTupleGetItemNode(obj_id);
}

AnfNodePtr GradExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
  ValuePtr converted_ret = nullptr;
  if (!parse::ConvertData(obj, &converted_ret)) {
    MS_LOG(EXCEPTION) << "Failed to convert obj to value node.";
  }
  auto node = NewValueNode(converted_ret);
  SetNodeMapInGraphInfoMap(curr_g(), obj_id, node);
  return node;
}

AnfNodePtr GradExecutor::CreateMakeTupleNode(const py::object &obj, const std::string &obj_id) {
  if (!py::isinstance<py::tuple>(obj) && !py::isinstance<py::list>(obj)) {
    MS_LOG(DEBUG) << "The input obj is not a tuple or list.";
    return nullptr;
  }
  // Get input node and value
  const auto &obj_tuple = obj.cast<py::tuple>();
  ValuePtrList input_args;
  std::vector<size_t> value_index;
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
  for (size_t i = 0; i < obj_tuple.size(); ++i) {
    const auto &v = PyObjToValue(obj_tuple[i]);
    // Graphs have no definition for grad
    if (v->isa<FuncGraph>()) {
      continue;
    }
    value_index.emplace_back(i);
    input_args.emplace_back(v);
    (void)CreateMakeTupleNode(obj_tuple[i], GetId(obj_tuple[i]));
    inputs.emplace_back(GetInput(obj_tuple[i], false));
  }
  py::tuple value_outs(value_index.size());
  for (size_t i = 0; i < value_index.size(); ++i) {
    value_outs[i] = obj_tuple[value_index[i]];
  }
  // Create make tuple node and record it in the graph info map
  auto cnode = curr_g()->NewCNode(inputs);
  MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
  SetTupleArgsToGraphInfoMap(curr_g(), obj, cnode);
  SetNodeMapInGraphInfoMap(curr_g(), obj_id, cnode);
  // Run ad for the make tuple node
  if (grad_flag_) {
    if (grad_is_running_ && !bprop_grad_stack_.empty() && !bprop_grad_stack_.top().second) {
      MS_LOG(DEBUG) << "Running custom bprop, no need to do GradPynativeOp.";
    } else {
      (void)ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, PyObjToValue(value_outs));
    }
  }
  return cnode;
}

AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id) {
  // obj_id is obtained by calling 'GetId()'
  auto graph_info = top_cell()->graph_info_map().at(curr_g());
  MS_EXCEPTION_IF_NULL(graph_info);
  if (graph_info->node_map.find(obj_id) == graph_info->node_map.end()) {
    MS_LOG(DEBUG) << "Can not find CNode for obj id: " << obj_id;
    return nullptr;
  }
  const auto &out = graph_info->node_map.at(obj_id);
  MS_LOG(DEBUG) << "Output size: " << out.second.size();
  auto c_node = out.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(c_node);
  auto abs = c_node->abstract();
  // Create tuple get item node
  for (const auto &idx : out.second) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), c_node, NewValueNode(idx)};
    c_node = curr_g()->NewCNode(tuple_get_item_inputs);
    if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
      auto abs_tuple = dyn_cast<abstract::AbstractTuple>(abs);
      MS_EXCEPTION_IF_NULL(abs_tuple);
      const auto &elements = abs_tuple->elements();
      if (static_cast<size_t>(idx) >= elements.size()) {
        MS_LOG(EXCEPTION) << "Index exceeds the size of elements. Index " << idx << ", element size "
                          << elements.size();
      }
      auto prim_abs = elements[static_cast<size_t>(idx)];
      MS_EXCEPTION_IF_NULL(prim_abs);
      MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
      c_node->set_abstract(prim_abs);
    }
  }
  if (c_node->abstract() != nullptr) {
    forward()->node_abs_map()[obj_id] = c_node->abstract();
  }
  MS_LOG(DEBUG) << "Create tuple get item node: " << c_node->DebugString();
  return c_node;
}
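
// NOTE: Top-cell lookup below distinguishes two cases. A complete already_run_cell_id
// match means the grad operation ran first; a prefix match whose stored id ends with
// '_' means forward ran first. A matched cell whose recorded grad operation differs
// from the current one is evicted so that the backward graph gets rebuilt.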
TopCellInfoPtr GradExecutor::GetTopCell(const std::string &already_run_cell_id) {
  TopCellInfoPtr find_top_cell = nullptr;
  for (const auto &top_cell : top_cell_list_) {
    MS_EXCEPTION_IF_NULL(top_cell);
    // Complete match, means the grad operation ran first
    if (top_cell->already_run_cell_id() == already_run_cell_id) {
      return top_cell;
    }
    // Partial match, means forward ran first
    if (already_run_cell_id.find(top_cell->already_run_cell_id()) != std::string::npos &&
        top_cell->already_run_cell_id().back() == '_') {
      find_top_cell = top_cell;
      break;
    }
  }
  // Same top cell info, but the grad operation is not the same; construct the backward graph again
  if (find_top_cell != nullptr) {
    if (!find_top_cell->grad_operation().empty() && find_top_cell->grad_operation() != grad_operation_) {
      MS_LOG(DEBUG) << "Already existing grad operation " << find_top_cell->grad_operation()
                    << " is different from the new " << grad_operation_;
      EraseTopCellFromTopCellList(find_top_cell);
      (void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
      return nullptr;
    } else {
      return find_top_cell;
    }
  }
  return nullptr;
}

void GradExecutor::EnableOpGraphCache(bool is_enable) {
  MS_LOG(DEBUG) << "Op cache is enabled: " << is_enable;
  enable_op_cache_ = is_enable;
  const auto inst = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(inst);
  inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
}
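
// NOTE: The recorded op_info concatenates the op name, its per-top-cell index, an input
// mask string ('w' for weight, 'd' for data) and, when available, the output shape,
// e.g. "Conv2D-3-dw-(32, 64, 112, 112)" (illustrative; the shape text is whatever
// BuildShape()->ToString() produces). The accumulated all_op_info is what later detects
// whether a cell's construct changed between runs.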
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
  if (!grad_flag_) {
    MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
    return;
  }
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_out);
  std::string input_args_info;
  // Record input args info (weight or data)
  for (const auto mask : op_exec_info->inputs_mask) {
    if (mask) {
      input_args_info += "w";
      continue;
    }
    input_args_info += "d";
  }
  // Record op name and index
  op_exec_info->op_info.clear();
  const auto &curr_op_num = top_cell()->op_num();
  op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
  // The out shape is added to distinguish ops that change the shape
  auto out_abs = op_out->ToAbstract();
  if (out_abs != nullptr) {
    auto out_shape = out_abs->BuildShape()->ToString();
    if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
      op_exec_info->op_info += "-" + out_shape;
    }
  }
  top_cell()->all_op_info() += "-" + op_exec_info->op_info;
  top_cell()->set_op_num(curr_op_num + 1);
}

void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const CNodePtr &cnode) {
  if (cell_stack_.empty()) {
    MS_LOG(DEBUG) << "No need to save output";
    return;
  }
  MS_EXCEPTION_IF_NULL(cnode);
  MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << " id " << obj_id;
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    auto size = static_cast<int64_t>(value.size());
    if (size > 1) {
      for (int64_t i = 0; i < size; ++i) {
        auto value_id = GetId(value[static_cast<size_t>(i)]);
        SetNodeMapInGraphInfoMap(curr_g(), value_id, cnode, i);
      }
    }
  }
  SetNodeMapInGraphInfoMap(curr_g(), obj_id, cnode);
}

// Run ad grad for the current op and connect its grad graph with the previous op
void GradExecutor::DoOpGrad(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const ValuePtr &op_out) {
  MS_EXCEPTION_IF_NULL(op_out);
  if (grad_is_running_ && !bprop_grad_stack_.top().second) {
    MS_LOG(DEBUG) << "Custom bprop, no need to do op grad";
    return;
  }
  ValuePtrList input_args;
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); ++i) {
    const auto &arg = PyObjToValue(op_exec_info->op_inputs[i]);
    input_args.emplace_back(arg);
  }
  if (!ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, op_out)) {
    MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_exec_info->op_name;
  }
}

void GradExecutor::UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info,
                                                  const ValuePtr &new_forward_value) {
  MS_LOG(DEBUG) << "Ms func graph has already run before. The graph phase is: " << graph_phase();
  MS_EXCEPTION_IF_NULL(new_forward_value);
  MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString();
  std::vector<tensor::TensorPtr> new_tensors;
  TensorValueToTensor(new_forward_value, &new_tensors);
  if (new_tensors.empty()) {
    MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update.";
    return;
  }
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &old_tensors = top_cell()->op_info_with_ms_func_forward_tensors().at(op_exec_info->op_info);
  if (old_tensors.size() != new_tensors.size()) {
    MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size()
                      << ", but the size of new tensors is: " << new_tensors.size()
                      << ", the current op info is: " << op_exec_info->op_info;
  }
  for (size_t i = 0; i < new_tensors.size(); ++i) {
    UpdateTensorInfo(new_tensors[i], {old_tensors[i]});
    old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost);
  }
}

void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
                                          ValuePtrList *input_values, CNodePtr *ms_function_cnode) {
  // Get input node info of ms_function
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
  MS_EXCEPTION_IF_NULL(input_values);
  for (size_t i = 0; i < args.size(); ++i) {
    auto input_i_node = GetInput(args[i], false);
    MS_EXCEPTION_IF_NULL(input_i_node);
    MS_LOG(DEBUG) << "The input " << i << " node of ms_function graph is: " << input_i_node->DebugString();
    input_nodes.emplace_back(input_i_node);
    const auto &inp_i_value = PyObjToValue(args[i]);
    MS_LOG(DEBUG) << "The input " << i << " value of ms_function graph is: " << inp_i_value->ToString();
    (*input_values).emplace_back(inp_i_value);
  }
  // Get dfbuilder and graph info map
  auto df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(df_builder);
  const auto &graph_info = top_cell()->graph_info_map().at(df_builder);
  MS_EXCEPTION_IF_NULL(graph_info);
  // Get weights info of ms_function
  std::vector<AnfNodePtr> new_params;
  auto manage = Manage(ms_func_graph, false);
  for (const auto &anf_node : ms_func_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(anf_node);
    auto param = anf_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    if (!param->has_default()) {
      new_params.push_back(param);
      continue;
    }
    auto param_info = param->param_info();
    MS_EXCEPTION_IF_NULL(param_info);
    auto param_name = param_info->name();
    if (graph_info->params.count(param_name)) {
      // Share the same weight parameter in different ms_function calls.
      auto same_param = graph_info->params.at(param_name);
      manage->Replace(anf_node, same_param);
      param = same_param;
    } else {
      df_builder->add_parameter(param);
      param->debug_info()->set_name(param_name);
    }
    new_params.push_back(param);
    input_nodes.emplace_back(param);
    (*input_values).emplace_back(param->default_param());
    SetParamNodeMapInGraphInfoMap(df_builder, param_name, param);
    MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
                  << param->default_param()->ToString() << ". Its name is: " << param_name;
  }
  ms_func_graph->set_parameters(new_params);
  manage->Clear();
  // Make a CNode which includes the ms_function fprop graph and inputs node
  MS_EXCEPTION_IF_NULL(ms_function_cnode);
  *ms_function_cnode = curr_g()->NewCNode(input_nodes);
  MS_LOG(DEBUG) << "Make ms function forward cnode: " << (*ms_function_cnode)->DebugString();
}

// Make adjoint for the ms_function fprop graph and connect it with the previous op
void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
                                            const py::object &actual_out, const py::args &args,
                                            const ValuePtr &actual_out_v) {
  ValuePtrList input_values;
  CNodePtr ms_function_cnode = nullptr;
  MakeCNodeForMsFunction(ms_func_graph, args, &input_values, &ms_function_cnode);
  MS_EXCEPTION_IF_NULL(ms_function_cnode);
  SetTupleArgsToGraphInfoMap(curr_g(), actual_out, ms_function_cnode);
  SetNodeMapInGraphInfoMap(curr_g(), GetId(actual_out), ms_function_cnode);
  // Connect the grad graph of ms_function to the context.
  auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
  MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
  MS_EXCEPTION_IF_NULL(grad_graph);
  if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, input_values, actual_out_v, grad_graph)) {
    MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
                      << ms_function_cnode->DebugString();
  }
  top_cell()->set_ms_function_flag(true);
}
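
// NOTE: On a repeated run of a known top cell, the tensors produced by the current op
// replace the stale forward tensors captured in the previously built bprop graph: the
// op_info recorded last run maps to the old tensor ids, and every id that still owns
// tensor objects has its device info refreshed through UpdateTensorInfo.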
void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
  if (!grad_flag_) {
    MS_LOG(DEBUG) << "The grad flag is false, no need to update forward op info in bprop graph";
    return;
  }
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_out);
  const auto &op_info = op_exec_info->op_info;
  MS_LOG(DEBUG) << "Current op info: " << op_info;
  std::vector<tensor::TensorPtr> all_op_tensors;
  // Get output tensors
  TensorValueToTensor(op_out, &all_op_tensors);
  // Save all tensors info of the current op
  if (need_construct_graph()) {
    SaveOpInfo(top_cell_, op_info, all_op_tensors);
  }
  // First run of the top cell
  if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
    MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " runs firstly";
    if (!need_construct_graph()) {
      MS_LOG(EXCEPTION) << "The cell stack is empty when running a new top cell " << top_cell_->cell_id();
    }
    return;
  }
  // Non-first run
  const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->already_run_cell_id());
  MS_EXCEPTION_IF_NULL(pre_top_cell);
  if (pre_top_cell->op_info_with_tensor_id().find(op_info) == pre_top_cell->op_info_with_tensor_id().end()) {
    MS_LOG(DEBUG) << "Can not find op info " << op_info << " in op info with tensor id map. Top cell "
                  << top_cell_->cell_id();
    return;
  }
  // Update new output tensor info in the bprop graph
  const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_info);
  if (pre_op_tensor_id.size() != all_op_tensors.size()) {
    MS_LOG(EXCEPTION) << "The size of pre op tensor id: " << pre_op_tensor_id.size()
                      << " is not equal to the size of all tensors of current op " << all_op_tensors.size();
  }
  const auto &pre_tensor_id_with_tensor_object = pre_top_cell->tensor_id_with_tensor_object();
  for (size_t i = 0; i < pre_op_tensor_id.size(); ++i) {
    auto pre_id = pre_op_tensor_id[i];
    if (pre_tensor_id_with_tensor_object.find(pre_id) == pre_tensor_id_with_tensor_object.end()) {
      continue;
    }
    const auto &new_tensor = all_op_tensors[i];
    const auto &pre_tensor_object = pre_tensor_id_with_tensor_object.at(pre_id);
    UpdateTensorInfo(new_tensor, pre_tensor_object);
  }
}

void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
  MS_EXCEPTION_IF_NULL(resource);
  // Get all tensor ids of forward ops
  mindspore::HashSet<std::string> forward_op_tensor_id;
  const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
  for (const auto &record : op_info_with_tensor_id) {
    std::for_each(record.second.begin(), record.second.end(),
                  [&forward_op_tensor_id](const std::string &tensor_id) { forward_op_tensor_id.emplace(tensor_id); });
  }
  // Get all tensor objects in the value nodes of the bprop graph
  const auto &bprop_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(bprop_graph);
  const auto &value_node_list = bprop_graph->value_nodes();
  std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
  for (const auto &elem : value_node_list) {
    auto value_node = elem.first->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
  }
  auto &tensor_id_with_tensor_object = top_cell()->tensor_id_with_tensor_object();
  if (!tensor_id_with_tensor_object.empty()) {
    MS_LOG(EXCEPTION) << "When compiling a top graph, the tensor_id_with_tensor_object map should be empty. Top cell: "
                      << top_cell()->cell_id();
  }
  // Save the tensors in the value nodes of the bprop graph
  for (const auto &tensor : tensors_in_bprop_graph) {
    MS_EXCEPTION_IF_NULL(tensor);
    if (forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end() || tensor->device_address() == nullptr) {
      continue;
    }
    tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor);
    MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
                  << " device address: " << tensor->device_address() << " shape and dtype "
                  << tensor->GetShapeAndDataTypeInfo();
  }
}
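
// NOTE: Backend selection is per op. InitEnv derives the policy from the context:
// "ms" selects kMsBackendMsPrior and anything else falls back to kMsBackendVmOnly
// (GE builds force kMsBackendGeOnly), while ops listed in kVmOperators are always
// pinned to the VM backend.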
py::tuple ForwardExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto backend_policy = InitEnv(op_exec_info);
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // Returns a null py::tuple on error
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(EXCEPTION) << "Failed to run " << op_exec_info->op_name;
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}

MsBackendPolicy ForwardExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(DEBUG) << "RunOp start, op name is: " << op_exec_info->op_name;
  parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!context::IsTsdOpened(ms_context)) {
    if (!context::OpenTsd(ms_context)) {
      MS_LOG(EXCEPTION) << "Open tsd failed";
    }
  }
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  context::PynativeInitGe(ms_context);
  backend_policy = kMsBackendGeOnly;
#endif
  if (kVmOperators.find(op_exec_info->op_name) != kVmOperators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  return backend_policy;
}

py::object ForwardExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                                   PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // Use vm only
      MS_LOG(DEBUG) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // Use GE first, use vm when GE fails
      MS_LOG(DEBUG) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // Use ms first, use others when ms fails
      MS_LOG(DEBUG) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!!!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}

py::object ForwardExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(DEBUG) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  auto &op_inputs = op_exec_info->op_inputs;
  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf" ||
      op_exec_info->op_name == "stop_gradient") {
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      auto tensor = py::cast<tensor::TensorPtr>(input);
      MS_EXCEPTION_IF_NULL(tensor);
      if (op_exec_info->op_name == "HookBackward") {
        // The input object is not an output of a forward cnode, e.g. a parameter
        result[i] = tensor;
      } else {
        // The input object is an output of a forward cnode
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
        new_tensor->set_device_address(tensor->device_address());
        new_tensor->set_sync_status(tensor->sync_status());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(DEBUG) << "RunOpInVM end";
    return std::move(result);
  }
  auto primitive = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(primitive);
  auto result = primitive->RunPyComputeFunction(op_inputs);
  MS_LOG(DEBUG) << "RunOpInVM end";
  if (py::isinstance<py::none>(result)) {
    MS_LOG(ERROR) << "VM got the result none, please check whether it failed to get the func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  *status = PYNATIVE_SUCCESS;
  if (py::isinstance<py::tuple>(result)) {
    return result;
  }
  py::tuple tuple_result = py::make_tuple(result);
  return std::move(tuple_result);
}
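
// NOTE: When two consecutive ops land on different device targets (heterogeneous
// execution), the executor synchronizes outstanding work on the previous target before
// launching on the new one; last_target_ tracks the most recent target for this check.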
void ForwardExecutor::CheckIfNeedSyncForHeterogeneous(const std::string &cur_target) {
  if (last_target_ != "Unknown" && last_target_ != cur_target) {
    auto executor = PynativeExecutor::GetInstance();
    executor->Sync();
  }
  last_target_ = cur_target;
}

py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(status);
  compile::SetMindRTEnable();
  MS_LOG(DEBUG) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
  const std::string &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto enable_mind_rt = ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
  std::string cur_target = GetCurrentDeviceTarget(device_target, op_exec_info->py_primitive);
  CheckIfNeedSyncForHeterogeneous(cur_target);
  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int64_t> tensors_mask;
  std::string graph_info;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  ConvertAttrToUnifyMindIR(op_exec_info);
  // Get graph info to check whether it already exists in the cache
  GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask, &graph_info);
#if defined(__APPLE__)
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive.get(),
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->lazy_build,
                                    op_exec_info->next_op_name,
                                    static_cast<int>(op_exec_info->next_input_index),
                                    graph_info,
                                    tensors_mask,
                                    input_tensors};
#else
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive.get(),
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->lazy_build,
                                    op_exec_info->next_op_name,
                                    op_exec_info->next_input_index,
                                    graph_info,
                                    tensors_mask,
                                    input_tensors};
#endif
  VectorRef outputs;
  if (!enable_mind_rt || cur_target == "Ascend") {
    auto cur_session = GetCurrentSession(cur_target, device_id);
    MS_EXCEPTION_IF_NULL(cur_session);
    cur_session->RunOp(&op_run_info, &outputs);
  } else {
    auto cur_mind_rt_backend = GetMindRtBackend(cur_target, device_id);
    MS_EXCEPTION_IF_NULL(cur_mind_rt_backend);
    mindspore::ScopedLongRunning long_running;
    cur_mind_rt_backend->RunOp(&op_run_info, &outputs);
  }
  if (op_exec_info->is_dynamic_shape) {
    op_exec_info->abstract = op_run_info.abstract;
  }
  auto result = BaseRefToPyData(outputs);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  *status = PYNATIVE_SUCCESS;
  MS_LOG(DEBUG) << "End run op [" << op_exec_info->op_name << "] with backend policy ms";
  return result;
}

void ForwardExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear forward res";
  lazy_build_ = false;
  implicit_cast_map_.clear();
  prim_abs_list_.clear();
  node_abs_map_.clear();
}

ForwardExecutorPtr GradExecutor::forward() const {
  auto forward_executor = forward_executor_.lock();
  MS_EXCEPTION_IF_NULL(forward_executor);
  return forward_executor;
}

TopCellInfoPtr GradExecutor::top_cell() const {
  MS_EXCEPTION_IF_NULL(top_cell_);
  return top_cell_;
}

FuncGraphPtr GradExecutor::curr_g() const {
  auto fg = top_cell()->fg();
  MS_EXCEPTION_IF_NULL(fg);
  return fg;
}

void GradExecutor::PushCellStack(const std::string &cell_id) { cell_stack_.push(cell_id); }

void GradExecutor::PopCellStack() {
  if (cell_stack_.empty()) {
    MS_LOG(EXCEPTION) << "Stack cell_stack_ is empty";
  }
  cell_stack_.pop();
}

void GradExecutor::PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) { high_order_stack_.push(top_cell); }

TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
  if (high_order_stack_.empty()) {
    MS_LOG(EXCEPTION) << "Stack high_order_stack_ is empty";
  }
  high_order_stack_.pop();
  TopCellInfoPtr top_cell = nullptr;
  if (!high_order_stack_.empty()) {
    top_cell = high_order_stack_.top();
  }
  return top_cell;
}

std::string GradExecutor::GetCellId(const py::object &cell, const py::args &args) {
  auto cell_id = GetId(cell);
  for (size_t i = 0; i < args.size(); i++) {
    const auto &arg_id = GetId(args[i]);
    auto it = forward()->node_abs_map().find(arg_id);
    if (it != forward()->node_abs_map().end()) {
      auto &abs = it->second;
      MS_EXCEPTION_IF_NULL(abs);
      auto shape = abs->BuildShape();
      MS_EXCEPTION_IF_NULL(shape);
      auto type = abs->BuildType();
      MS_EXCEPTION_IF_NULL(type);
      cell_id += "_" + shape->ToString();
      cell_id += type->ToString();
    } else {
      auto value = PyObjToValue(args[i]);
      MS_EXCEPTION_IF_NULL(value);
      auto abs = value->ToAbstract();
      MS_EXCEPTION_IF_NULL(abs);
      if (abs->isa<abstract::AbstractTensor>()) {
        abs->set_value(kAnyValue);
      }
      forward()->node_abs_map()[arg_id] = abs;
      auto shape = abs->BuildShape();
      MS_EXCEPTION_IF_NULL(shape);
      auto type = abs->BuildType();
      MS_EXCEPTION_IF_NULL(type);
      cell_id += "_" + shape->ToString();
      cell_id += type->ToString();
    }
  }
  return cell_id;
}
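
// NOTE: A cell id is the cell object's id followed by "_<shape><type>" per argument, so
// the same cell called with different input shapes or dtypes is tracked as a different
// cell. Something like "<obj id>_(32, 3, 224, 224)Tensor[Float32]" is a plausible shape
// of the result (illustrative; the exact text comes from BuildShape()/BuildType()).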
void GradExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
#ifdef ENABLE_DUMP_IR
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR(filename, graph);
  }
#endif
}

inline bool GradExecutor::IsNestedGrad() const {
  MS_LOG(DEBUG) << "Grad nested order is " << grad_order_;
  return grad_order_ > 1;
}

bool GradExecutor::IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const {
  // Just compare the obj id, ignore the args id
  return l_cell_id.compare(0, PTR_LEN, r_cell_id, 0, PTR_LEN) == 0;
}

bool GradExecutor::IsBpropGraph(const std::string &cell_id) {
  if (top_cell_ == nullptr) {
    return false;
  }
  return std::any_of(bprop_cell_list_.begin(), bprop_cell_list_.end(),
                     [&cell_id](const std::string &value) { return cell_id.find(value) != std::string::npos; });
}

void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled) {
  top_cell()->set_vm_compiled(vm_compiled);
  top_cell()->set_need_compile_graph(need_compile_graph);
  top_cell()->set_forward_already_run(forward_already_run);
}

void GradExecutor::ClearCellRes(const std::string &cell_id) {
  static bool clear_all_cell_res = false;
  // Grad clean
  if (cell_id.empty()) {
    MS_LOG(DEBUG) << "Clear all cell resources";
    clear_all_cell_res = true;
    for (const auto &iter : top_cell_list_) {
      MS_EXCEPTION_IF_NULL(iter);
      iter->Clear();
    }
    top_cell_list_.clear();
    already_run_top_cell_.clear();
    clear_all_cell_res = false;
    return;
  }
  if (clear_all_cell_res) {
    MS_LOG(DEBUG) << "In the process of clearing all cell resources, no need to clear a single cell resource again";
    return;
  }
  // Clear on cell destruction
  for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
    MS_EXCEPTION_IF_NULL(*it);
    const auto &top_cell_id = (*it)->cell_id();
    const auto &already_run_cell_id = (*it)->already_run_cell_id();
    if (IsCellObjIdEq(cell_id, top_cell_id)) {
      MS_LOG(DEBUG) << "Clear top cell resource. Top cell id " << top_cell_id;
      (*it)->Clear();
      it = top_cell_list_.erase(it);
      (void)already_run_top_cell_.erase(already_run_cell_id);
      continue;
    }
    ++it;
  }
}

void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop_top) {
  if (is_bprop_top) {
    // Convert input args to parameters for the top cell graph in bprop.
    for (size_t i = 0; i < args.size(); ++i) {
      auto param = args[i];
      auto new_param = curr_g()->add_parameter();
      const auto &param_id = GetId(param);
      SetTupleArgsToGraphInfoMap(curr_g(), param, new_param, true);
      SetNodeMapInGraphInfoMap(curr_g(), param_id, new_param);
      SetParamNodeMapInGraphInfoMap(curr_g(), param_id, new_param);
    }
    return;
  }
  // Convert input args to parameters for the top cell graph in construct.
  std::vector<ValuePtr> input_param_values;
  const auto &only_tensors = FilterTensorArgs(args);
  for (size_t i = 0; i < only_tensors.size(); ++i) {
    auto new_param = curr_g()->add_parameter();
    auto param_i = only_tensors[i];
    const auto &param_i_value = PyObjToValue(param_i);
    input_param_values.emplace_back(param_i_value);
    auto param_i_abs = param_i_value->ToAbstract();
    MS_EXCEPTION_IF_NULL(param_i_abs);
    new_param->set_abstract(param_i_abs->Broaden());
    const auto &param_i_id = GetId(param_i);
    SetTupleArgsToGraphInfoMap(curr_g(), param_i, new_param, true);
    SetNodeMapInGraphInfoMap(curr_g(), param_i_id, new_param);
    SetParamNodeMapInGraphInfoMap(curr_g(), param_i_id, new_param);
    SetParamNodeMapInGraphInfoMap(top_cell_->df_builder(), param_i_id, new_param);
  }
  top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g()->parameters(), input_param_values));
}
void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args) {
  if (cell_stack_.empty() || IsNestedGrad()) {
    if (cell_stack_.empty() && !grad_is_running_) {
      MS_LOG(DEBUG) << "Make new topmost graph";
      MakeNewTopGraph(cell_id, args, true);
    } else if (grad_is_running_ && IsBpropGraph(cell_id)) {
      MS_LOG(DEBUG) << "Run bprop cell";
      auto fg = std::make_shared<FuncGraph>();
      top_cell()->set_fg(fg);
      auto graph_info_cg = std::make_shared<GraphInfo>(cell_id);
      top_cell()->graph_info_map()[fg] = graph_info_cg;
      HandleInputArgsForTopCell(args, true);
      bprop_grad_stack_.push(std::make_pair(cell_id, false));
    } else if (grad_is_running_ && top_cell()->grad_order() != grad_order_) {
      MS_LOG(DEBUG) << "Nested grad graph existed in bprop";
      MakeNewTopGraph(cell_id, args, false);
      bprop_grad_stack_.push(std::make_pair(cell_id, true));
    } else if (!cell_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_) {
      MS_LOG(DEBUG) << "Nested grad graph existed in construct";
      auto cur_top_is_dynamic = top_cell()->is_dynamic();
      MakeNewTopGraph(cell_id, args, false);
      top_cell()->set_is_dynamic(cur_top_is_dynamic);
    }
  }
  PushCellStack(cell_id);
  // Init kPynativeCellPtr with input parameters of top cell
  if (!top_cell()->is_init_kpynative()) {
    auto graph_info_cg = std::make_shared<GraphInfo>(cell_id);
    top_cell()->graph_info_map()[curr_g()] = graph_info_cg;
    auto graph_info_df = std::make_shared<GraphInfo>(cell_id);
    top_cell()->graph_info_map()[top_cell_->df_builder()] = graph_info_df;
    HandleInputArgsForTopCell(args, false);
    top_cell()->set_need_compile_graph(true);
    top_cell()->set_init_kpynative(true);
  } else {
    // Non-top cell
    top_cell()->sub_cell_list().emplace(cell_id);
  }
}
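// NewGraphInner is the entry hook when a cell's construct begins under grad. If
// the top cell has already run and is not dynamic, its cached info is reused and
// the remaining bookkeeping is skipped; otherwise resources are (re)initialized
// and the cell is checked for dynamic control flow.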
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
  if (top_cell_ != nullptr && cell_stack_.empty()) {
    // An already-run top cell needs to distinguish high-order cases; high order appends "0", otherwise "1"
    const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
    auto top_it = already_run_top_cell_.find(already_run_cell_id);
    if (top_it != already_run_top_cell_.end()) {
      // Top cell forward run.
      const auto &pre_top_cell = top_it->second;
      MS_EXCEPTION_IF_NULL(pre_top_cell);
      if (!pre_top_cell->is_dynamic()) {
        MS_LOG(DEBUG) << "Top cell " << cell_id << " is not dynamic, no need to run NewGraphInner again";
        ResetTopCellInfo(pre_top_cell, args);
        PushHighOrderGraphStack(pre_top_cell);
        set_top_cell(pre_top_cell);
        grad_order_ = pre_top_cell->grad_order();
        return;
      }
    } else if ((top_cell()->IsSubCell(cell_id) || GetHighOrderStackSize() >= 1) &&
               !IsCellObjIdEq(cell_id, check_graph_cell_id_)) {
      // Sub cell (or possibly a temporary cell, but it must be non-top) forward run in cache process.
      MS_LOG(DEBUG) << "Sub cell no need to run NewGraphInner again";
      return;
    }
  }
  // When the cell has a custom bprop, custom_bprop_cell_count_ is larger than 0
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    custom_bprop_cell_count_ += 1;
  }
  // Make top graph and init resource and df_builder
  InitResourceAndDfBuilder(cell_id, args);
  // Check whether cell has dynamic construct
  if (!top_cell()->is_dynamic()) {
    bool is_dynamic = parse::DynamicParser::IsDynamicCell(cell);
    MS_LOG(DEBUG) << "Current cell dynamic " << is_dynamic;
    if (is_dynamic) {
      top_cell()->set_is_dynamic(is_dynamic);
    }
  }
}
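// MakeNewTopGraph validates the args, caps the number of live top cells at
// MAX_TOP_CELL_COUNTS (evicting the last entry in the list and disabling the op
// graph cache once the cap is hit), and registers a fresh TopCellInfo with its
// own FuncGraph, df_builder and pipeline Resource.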
void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest) {
  pipeline::CheckArgsValid(args);
  // Record input args info
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id += GetId(args[i]) + "_";
  }
  // Running forward first requires grad_order_ to be increased by 1
  if (grad_order_ == 0) {
    ++grad_order_;
  }
  // If the number of top cells exceeds MAX_TOP_CELL_COUNTS, delete the last one to cap the length of the list and
  // disable the backend cache
  if (top_cell_list_.size() >= MAX_TOP_CELL_COUNTS) {
    EnableOpGraphCache(false);
    const auto last_top_cell = top_cell_list_.back();
    top_cell_list_.pop_back();
    MS_EXCEPTION_IF_NULL(last_top_cell);
    last_top_cell->Clear();
    (void)already_run_top_cell_.erase(last_top_cell->already_run_cell_id());
  }
  // Create top cell
  auto fg = std::make_shared<FuncGraph>();
  auto df_builder = std::make_shared<FuncGraph>();
  auto resource = std::make_shared<pipeline::Resource>();
  const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
  auto top_cell =
    std::make_shared<TopCellInfo>(is_topest, grad_order_, resource, fg, df_builder, cell_id, already_run_cell_id);
  top_cell->set_forward_already_run(true);
  top_cell->set_input_args_id(input_args_id);
  top_cell_list_.emplace_back(top_cell);
  PushHighOrderGraphStack(top_cell);
  set_top_cell(top_cell);
  MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
}
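// The two helpers below record, for every element of a (possibly nested) tuple
// or list argument, which AnfNode and which index path produce it. As a
// hypothetical illustration: for an arg ((a, b), c) mapped to node n, a would be
// recorded with the index sequence {0, 0}, b with {0, 1} and c with {1}, so
// later lookups can emit the right chain of tuple item accesses.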
void GradExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                              bool is_param) {
  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
    return;
  }
  auto tuple = args.cast<py::tuple>();
  auto tuple_size = static_cast<int64_t>(tuple.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    // Tuple slicing uses size_t
    auto id = GetId(tuple[static_cast<size_t>(i)]);
    if (is_param && node->isa<Parameter>()) {
      auto param = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      SetParamNodeMapInGraphInfoMap(g, id, param);
    }
    SetNodeMapInGraphInfoMap(g, id, node, i);
    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param);
  }
}
void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                                  const std::vector<int64_t> &index_sequence, bool is_param) {
  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  auto tuple = args.cast<py::tuple>();
  auto tuple_size = static_cast<int64_t>(tuple.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    std::vector<int64_t> tmp = index_sequence;
    tmp.emplace_back(i);
    // Tuple slicing uses size_t
    auto id = GetId(tuple[static_cast<size_t>(i)]);
    if (is_param && node->isa<Parameter>()) {
      auto param = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      SetParamNodeMapInGraphInfoMap(g, id, param);
    }
    SetNodeMapInGraphInfoMap(g, id, node, tmp);
    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param);
  }
}
void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
  if (cell_stack_.empty()) {
    if (cell_id == top_cell()->cell_id()) {
      if (top_cell()->is_topest()) {
        set_grad_flag(false);
      }
      if (GetHighOrderStackSize() < ARG_SIZE) {
        auto outer_top_cell = PopHighOrderGraphStack();
        if (outer_top_cell != nullptr) {
          set_top_cell(outer_top_cell);
        }
      }
    }
    MS_LOG(DEBUG) << "Current cell " << cell_id << " no need to run EndGraphInner again";
    return;
  }
  DoGradForCustomBprop(cell, out, args);
  PopCellStack();
  if (grad_is_running_ && !bprop_grad_stack_.empty()) {
    if (!bprop_grad_stack_.top().second) {
      curr_g()->set_output(GetObjNode(out, GetId(out)));
      bprop_grad_stack_.pop();
      return;
    } else if (bprop_grad_stack_.top().first == cell_id) {
      bprop_grad_stack_.pop();
    }
  }
  // Only dump the last forward graph
  bool is_top_cell_end = cell_id == top_cell()->cell_id();
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && is_top_cell_end) {
    curr_g()->set_output(GetObjNode(out, GetId(out)));
#ifdef ENABLE_DUMP_IR
    DumpIR("fg.ir", curr_g());
#endif
  }
  // Reset grad flag and update output node of the outermost cell
  if (cell_stack_.empty() && is_top_cell_end) {
    MS_LOG(DEBUG) << "Cur top last cell " << cell_id;
    PopHighOrderGraphStack();
    auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
    MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
    k_pynative_cell_ptr->UpdateOutputNodeOfTopCell(GetObjNode(out, GetId(out)));
    set_grad_flag(false);
  }
  // Check whether the graph needs to be compiled once each top cell has finished running
  if (is_top_cell_end) {
    // In high-order grad cases, the output of the inner graph may be a tuple, and the node needs to be created in
    // GetObjNode
    if (!cell_stack_.empty()) {
      (void)GetObjNode(out, GetId(out));
    }
    CheckNeedCompileGraph();
  }
}
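// DoGradForCustomBprop wraps a user-defined bprop as a fake HookBackward
// primitive: the bprop function becomes the hook, the cell's positional args
// (excluding self, out and dout) become the op inputs, and the resulting CNode
// is fed through the normal op-grad path so the custom gradient participates in
// autodiff.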
void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args) {
  if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    return;
  }
  custom_bprop_cell_count_ -= 1;
  if (custom_bprop_cell_count_ != 0) {
    return;
  }
  MS_LOG(DEBUG) << "Do grad for custom bprop";
  size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
  if (par_number > 0) {
    MS_LOG(EXCEPTION) << "When user defines the net bprop, the 'Parameter' data type is not supported in the net.";
  }
  py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
  auto bprop_func_cellid = GetId(bprop_func);
  bprop_cell_list_.emplace_back(bprop_func_cellid);
  auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());
  if (py::isinstance<Cell>(cell)) {
    auto cell_ptr = py::cast<CellPtr>(cell);
    fake_prim->set_bprop_cls_name(cell_ptr->name());
  }
  fake_prim->set_hook(bprop_func);
  const auto &cell_id = GetCellId(cell, args);
  (void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
  (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
  py::object code_obj = py::getattr(bprop_func, "__code__");
  py::object co_name = py::getattr(code_obj, "co_name");
  if (std::string(py::str(co_name)) == "staging_specialize") {
    MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
  }
  // The three parameters self, out and dout need to be excluded
  const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
  if (inputs_num > args.size()) {
    MS_EXCEPTION(TypeError) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
                            << args.size() << "]";
  }
  py::list cell_inputs;
  for (size_t i = 0; i < inputs_num; i += 1) {
    cell_inputs.append(args[i]);
  }
  OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
  op_exec_info->op_name = fake_prim->name();
  op_exec_info->py_primitive = fake_prim;
  op_exec_info->op_inputs = cell_inputs;
  auto cnode = forward()->ConstructForwardGraph(op_exec_info);
  const auto &v_out = PyObjToValue(out);
  DoOpGrad(op_exec_info, cnode, v_out);
  const auto &out_obj_id = GetId(out);
  SaveOutputNodeMap(out_obj_id, out, cnode);
}
std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) {
  std::string already_run_cell_id(cell_id);
  already_run_cell_id += std::to_string(grad_order_ == 0 ? 1 : grad_order_);
  already_run_cell_id += "_" + grad_operation_;
  MS_LOG(DEBUG) << "Get already run top cell id " << already_run_cell_id;
  return already_run_cell_id;
}
std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args) {
  size_t forward_args_size = args.size();
  py::args tmp = args;
  if (has_sens) {
    forward_args_size--;
    py::tuple f_args(forward_args_size);
    for (size_t i = 0; i < forward_args_size; ++i) {
      f_args[i] = args[i];
    }
    tmp = f_args;
  }
  const auto &cell_id = GetCellId(cell, tmp);
  return cell_id;
}
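// GradNetInner drives the actual backward compilation: it collects the weight
// and position arguments, obtains the bprop graph from GetBpropGraph, and then
// runs TaskEmitAction/ExecuteAction so the backend holds a compiled graph ready
// for RunGradGraph.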
void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
                                const py::object &weights, const py::object &grad_position, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  MS_EXCEPTION_IF_NULL(grad);
  auto size = args.size();
  const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args);
  MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id;
  if (!top_cell()->need_compile_graph()) {
    MS_LOG(DEBUG) << "No need compile graph";
    if (!cell_stack_.empty()) {
      UpdateTopCellInfo(false, false, true);
    } else {
      UpdateTopCellInfo(false, false, false);
    }
    return;
  }
  top_cell()->set_grad_operation(grad_operation_);
  auto resource = top_cell()->resource();
  MS_EXCEPTION_IF_NULL(resource);
  auto df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(df_builder);
  MS_LOG(DEBUG) << "fg ptr " << curr_g().get() << " resource ptr " << resource.get();
  // Get the params (weights) that require derivatives
  auto w_args = GetWeightsArgs(weights, df_builder);
  auto p_args = GetGradPositionArgs(grad_position);
  if (w_args.empty() && !df_builder->parameters().empty()) {
    MS_LOG(DEBUG) << "Add weights params to w_args";
    w_args.insert(w_args.end(), df_builder->parameters().begin(), df_builder->parameters().end());
  }
  // Get bprop graph of top cell
  auto bprop_graph = GetBpropGraph(grad, cell, w_args, p_args, size, args);
  resource->set_func_graph(bprop_graph);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(bprop_graph, true);
  DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
  // Launch bprop graph to backend
  SaveForwardTensorInfoInBpropGraph(resource);
  compile::SetMindRTEnable();
  resource->results()[pipeline::kBackend] = compile::CreateBackend();
  MS_LOG(DEBUG) << "Start task emit action";
  TaskEmitAction(resource);
  MS_LOG(DEBUG) << "Start execute action";
  ExecuteAction(resource);
  MS_LOG(DEBUG) << "Start update top cell info when run finish";
  UpdateTopCellInfo(false, false, true);
  resource->Clean();
  abstract::AnalysisContext::ClearContext();
}
std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) {
  MS_EXCEPTION_IF_NULL(df_builder);
  if (!py::hasattr(weights, "__parameter_tuple__")) {
    MS_LOG(DEBUG) << "No parameter tuple found";
    return {};
  }
  const auto &tuple = weights.cast<py::tuple>();
  MS_LOG(DEBUG) << "Get weights tuple size " << tuple.size();
  std::vector<AnfNodePtr> w_args;
  for (size_t it = 0; it < tuple.size(); ++it) {
    auto param = tuple[it];
    auto param_id = GetId(param);
    auto &graph_info_map = top_cell()->graph_info_map();
    if (graph_info_map.find(df_builder) == graph_info_map.end()) {
      MS_LOG(EXCEPTION) << "Can not find df_builder " << df_builder.get() << " Top cell " << top_cell().get()
                        << " cell id " << top_cell()->cell_id();
    }
    auto graph_info = graph_info_map.at(df_builder);
    MS_EXCEPTION_IF_NULL(graph_info);
    AnfNodePtr para_node = nullptr;
    if (graph_info->params.find(param_id) != graph_info->params.end()) {
      para_node = graph_info->params.at(param_id);
      w_args.emplace_back(para_node);
      continue;
    }
    const auto &name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have a name attribute";
    }
    const auto &param_name = py::cast<std::string>(name_attr);
    MS_LOG(DEBUG) << "The input " << it << " parameter weight name " << param_name;
    if (graph_info->params.find(param_name) != graph_info->params.end()) {
      para_node = graph_info->params.at(param_name);
    } else {
      MS_LOG(DEBUG) << "Can not find input param in graph info map, make a new parameter";
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(param);
      free_param->set_default_param(value);
      free_param->debug_info()->set_name(param_name);
      para_node = free_param;
    }
    w_args.emplace_back(para_node);
  }
  return w_args;
}
std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_position) {
  std::vector<size_t> pos_args;
  if (py::isinstance<py::tuple>(grad_position)) {
    const auto &tuple = grad_position.cast<py::tuple>();
    for (size_t it = 0; it < tuple.size(); ++it) {
      auto param = tuple[it];
      auto param_id = GetId(param);
      pos_args.push_back(std::stoi(param_id));
    }
    return pos_args;
  }
  MS_LOG(EXCEPTION) << "Grad position only support tuple.";
}
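// UpdateParamAbsByArgs refreshes every bprop-graph parameter's abstract: weight
// parameters are broadened from their default values, input parameters from the
// current call's arguments, with shape/dtype consistency checks against the IR
// (constant inputs with shape "()" are skipped, and a changed sens shape forces
// renormalization).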
void GradExecutor::UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph) {
  MS_EXCEPTION_IF_NULL(bprop_graph);
  const auto &bprop_params = bprop_graph->parameters();
  // bprop_params includes both inputs and weight parameters, so its size is at least the number of inputs
  if (bprop_params.size() < args.size()) {
    MS_LOG(EXCEPTION) << "Df parameters size " << bprop_params.size() << " less than " << args.size();
  }
  size_t index = 0;
  for (const auto &param : bprop_params) {
    auto param_node = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      // Update abstract info for weights
      ValuePtr value = param_node->default_param();
      MS_EXCEPTION_IF_NULL(value);
      auto ptr = value->ToAbstract();
      MS_EXCEPTION_IF_NULL(ptr);
      param_node->set_abstract(ptr->Broaden());
    } else {
      // Update abstract info for input params
      auto input_abs = abstract::FromValue(PyObjToValue(args[index]), true);
      if (param_node->abstract() != nullptr) {
        auto input_shape = input_abs->BuildShape()->ToString();
        auto param_tensor_abs = param_node->abstract();
        if (param_tensor_abs->isa<abstract::AbstractRef>()) {
          param_tensor_abs = param_tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
        }
        auto ir_shape = param_tensor_abs->BuildShape()->ToString();
        // Exclude const input
        if (input_shape != "()" && ir_shape != "()") {
          if (input_shape != ir_shape) {
            MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << ", "
                                     << param->DebugString();
          }
          auto ir_dtype = param_tensor_abs->BuildType()->ToString();
          auto input_dtype = input_abs->BuildType()->ToString();
          if (input_dtype != ir_dtype) {
            MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << ", "
                                    << param->DebugString();
          }
        }
        if (param_node->debug_info()->name() == "sens" && ir_shape != input_shape) {
          need_renormalize_ = true;
        }
      }
      param_node->set_abstract(input_abs->Broaden());
      index++;
    }
  }
}
FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
                                         const std::vector<AnfNodePtr> &weights,
                                         const std::vector<size_t> &grad_position, size_t arg_size,
                                         const py::args &args) {
  bool build_formal_param = false;
  if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !cell_stack_.empty() && IsNestedGrad()) {
    build_formal_param = true;
    need_renormalize_ = true;
  }
  if (top_cell()->ms_function_flag()) {
    need_renormalize_ = true;
  }
  auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
  MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
  MS_EXCEPTION_IF_NULL(grad);
  FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(k_pynative_cell_ptr, weights, grad_position, grad->get_all_,
                                                     grad->get_by_list_, grad->sens_param_, build_formal_param);
  MS_EXCEPTION_IF_NULL(bprop_graph);
  MS_LOG(DEBUG) << "Top graph input params size " << arg_size;
  std::ostringstream ss;
  ss << "grad{" << arg_size << "}";
  bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop_graph->debug_info()->set_name(ss.str());
  // Get the parameter items and add the values to args_spec
  UpdateParamAbsByArgs(FilterTensorArgs(args, grad->sens_param_), bprop_graph);
  // Do opt for final bprop graph
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  resource->set_func_graph(bprop_graph);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(bprop_graph);
  auto optimized_bg = ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);
  if (cell_stack_.empty()) {
    need_renormalize_ = false;
  }
  DumpGraphIR("after_final_opt.ir", optimized_bg);
  return optimized_bg;
}
py::object GradExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  BaseRef ret = false;
  check_graph_cell_id_ = GetCellId(cell, args);
  if (!(top_cell_ != nullptr && check_graph_cell_id_.find(top_cell_->cell_id()) != std::string::npos &&
        grad_order_ >= 1)) {
    ++grad_order_;
  }
  if (!grad_is_running_) {
    MS_LOG(DEBUG) << "Grad not running yet";
    return BaseRefToPyData(ret);
  }
  MS_LOG(DEBUG) << "Key is " << check_graph_cell_id_;
  if (top_cell_ != nullptr) {
    for (auto it = top_cell_->sub_cell_list().begin(); it != top_cell_->sub_cell_list().end(); ++it) {
      MS_LOG(DEBUG) << "Cur cell id " << *it;
      if (!IsCellObjIdEq(*it, check_graph_cell_id_)) {
        continue;
      }
      MS_LOG(DEBUG) << "Delete cellid from cell graph list, top cell is " << top_cell_;
      top_cell_->sub_cell_list().erase(it);
      ret = true;
      break;
    }
  }
  return BaseRefToPyData(ret);
}
py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
                                         const py::args &args) {
  bool forward_run = false;
  // Get cell id and input args info
  const auto &cell_id = GetCellId(cell, args);
  grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_);
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id += GetId(args[i]) + "_";
  }
  // When the stack is empty (forward process completed or no forward process), check whether the forward process
  // needs to run
  if (cell_stack_.empty() && top_cell_ != nullptr) {
    const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
    auto find_top_cell = GetTopCell(check_already_run_cell_id);
    if (find_top_cell != nullptr) {
      MS_LOG(DEBUG) << "Find already run top cell";
      forward_run = find_top_cell->forward_already_run();
      auto curr_top_cell = top_cell();
      set_top_cell(find_top_cell);
      bool input_args_changed =
        !find_top_cell->input_args_id().empty() && find_top_cell->input_args_id() != input_args_id;
      if (forward_run && input_args_changed && find_top_cell->is_dynamic()) {
        MS_LOG(WARNING) << "The construct of running cell is dynamic and the input info of this cell has changed, "
                           "forward process will run again";
        forward_run = false;
      }
      if (forward_run && GetHighOrderStackSize() >= 1) {
        PushHighOrderGraphStack(curr_top_cell);
      }
    }
  }
  MS_LOG(DEBUG) << "Graph has already run " << forward_run << " top cell id " << cell_id;
  return BaseRefToPyData(forward_run);
}
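// CheckNeedCompileGraph compares the recorded op info of the freshly run top
// cell against the cached one with the same already_run_cell_id. A mismatch
// means the graph must be recompiled (and frequent switches eventually disable
// the op graph cache); a match lets the cached top cell be reused and the new
// one be discarded.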
void GradExecutor::CheckNeedCompileGraph() {
  auto new_top_cell = top_cell();
  const auto &already_top_cell_id = new_top_cell->already_run_cell_id();
  // Update top cell by current cell op info
  if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
    MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has never been run, need to compile graph";
    already_run_top_cell_[already_top_cell_id] = new_top_cell;
    return;
  }
  MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has been run";
  auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
  MS_EXCEPTION_IF_NULL(pre_top_cell);
  const auto &pre_all_op_info = pre_top_cell->all_op_info();
  const auto &new_all_op_info = new_top_cell->all_op_info();
  MS_LOG(DEBUG) << "Pre all op info : " << pre_all_op_info;
  MS_LOG(DEBUG) << "New all op info : " << new_all_op_info;
  if (pre_all_op_info != new_all_op_info) {
    MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again";
    // The number of top cell switches exceeds MAX_TOP_CELL_COUNTS under control flow; disable the backend cache
    if (top_cell_switch_counts_ >= MAX_TOP_CELL_COUNTS) {
      EnableOpGraphCache(false);
    } else {
      // Increase the top cell switch count
      ++top_cell_switch_counts_;
    }
    EraseTopCellFromTopCellList(pre_top_cell);
    pre_top_cell->Clear();
    already_run_top_cell_[already_top_cell_id] = new_top_cell;
    g_pyobj_id_cache.clear();
  } else {
    MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
    pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
    // In high-order situations, the inner top cell remains unchanged while the outer top cell has changed, so
    // the graph info of the inner top cell needs to be updated for the outer top cell to perceive it.
    if (!cell_stack_.empty()) {
      pre_top_cell->graph_info_map()[pre_top_cell->df_builder()] =
        new_top_cell->graph_info_map()[new_top_cell->df_builder()];
    }
    EraseTopCellFromTopCellList(new_top_cell);
    new_top_cell->Clear();
    pre_top_cell->set_forward_already_run(true);
    set_top_cell(pre_top_cell);
  }
}
void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "Run start cell id " << cell_id;
  auto has_sens = std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) {
    return cell_id.find(value->cell_id()) != std::string::npos && cell_id != value->cell_id();
  });
  MS_LOG(DEBUG) << "Run has sens " << has_sens << " cell id " << cell_id;
  auto resource = top_cell()->resource();
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Run resource ptr " << resource.get();
  VectorRef arg_list;
  py::tuple converted_args = ConvertArgs(FilterTensorArgs(args, has_sens));
  pipeline::ProcessVmArgInner(converted_args, resource, &arg_list);
  MS_LOG(DEBUG) << "Convert args size " << converted_args.size() << ", graph param size " << arg_list.size();
  if (resource->results().find(pipeline::kOutput) == resource->results().end()) {
    MS_LOG(EXCEPTION) << "Can't find run graph output";
  }
  if (!resource->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Run graph is not VmEvalFuncPtr";
  }
  compile::VmEvalFuncPtr run = resource->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  MS_EXCEPTION_IF_NULL(run);
  const auto &backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  grad_is_running_ = true;
  BaseRef value = (*run)(arg_list);
  grad_is_running_ = false;
  MS_LOG(DEBUG) << "Eval run end " << value.ToString();
  *ret = BaseRefToPyData(value);
  // Clear device memory resource of top cell when it has been run.
  auto has_higher_order = std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
                                      [](const TopCellInfoPtr &value) { return !value->is_topest(); });
  if (top_cell()->is_topest() && !has_higher_order) {
    top_cell()->ClearDeviceMemory();
  }
  // High order
  if (top_cell()->vm_compiled()) {
    MakeNestedCnode(cell, converted_args, resource, *ret);
  } else if (GetHighOrderStackSize() >= ARG_SIZE) {
    SwitchTopcell();
  }
}
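// SwitchTopcell merges the op info recorded by the inner top cell into the
// outer one popped from the high-order stack, propagates the dynamic flag, and
// makes the outer cell current again.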
void GradExecutor::SwitchTopcell() {
  const auto &inner_top_cell_all_op_info = top_cell()->all_op_info();
  bool inner_top_cell_is_dynamic = top_cell()->is_dynamic();
  // Get outer top cell
  auto outer_top_cell = PopHighOrderGraphStack();
  MS_EXCEPTION_IF_NULL(outer_top_cell);
  outer_top_cell->all_op_info() += inner_top_cell_all_op_info;
  // If the inner top cell is dynamic, set the outer one dynamic too
  if (inner_top_cell_is_dynamic) {
    outer_top_cell->set_is_dynamic(inner_top_cell_is_dynamic);
  }
  set_top_cell(outer_top_cell);
}
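// DoParameterReplace unifies the parameters of the inner (first) grad graph with
// those of the outer (second) graph: input parameters present in both graphs are
// replaced by the outer ones, and weight parameters missing from the outer
// df_builder are added to it, so both graphs reference a single set of nodes.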
void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const py::tuple &forward_args,
                                      std::vector<AnfNodePtr> *inputs, ValuePtrList *weights_args) {
  MS_EXCEPTION_IF_NULL(inputs);
  MS_EXCEPTION_IF_NULL(weights_args);
  auto first_df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(first_df_builder);
  auto first_graph_info = top_cell()->graph_info_map().at(first_df_builder);
  MS_EXCEPTION_IF_NULL(first_graph_info);
  SwitchTopcell();
  auto second_df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(second_df_builder);
  auto second_graph_info = top_cell()->graph_info_map().at(second_df_builder);
  MS_EXCEPTION_IF_NULL(second_graph_info);
  mindspore::HashSet<std::string> params_weights_set;
  mindspore::HashSet<std::string> params_inputs_set;
  for (const auto &sec : second_graph_info->params) {
    if (sec.second->has_default()) {
      params_weights_set.emplace(sec.first);
    } else {
      params_inputs_set.insert(sec.first);
    }
  }
  auto manager = Manage({first_grad_fg}, false);
  // Replace input params
  for (size_t i = 0; i < forward_args.size(); ++i) {
    const auto &id = GetId(forward_args[i]);
    if (params_inputs_set.count(id)) {
      // Can be found in the second graph
      const auto &input_param_second = second_graph_info->params.at(id);
      manager->Replace(first_graph_info->params.at(id), input_param_second);
      inputs->emplace_back(input_param_second);
    } else {
      inputs->emplace_back(GetInput(forward_args[i], false));
    }
  }
  // Replace weight params
  for (const auto &fir : first_graph_info->params) {
    if (!fir.second->has_default()) {
      continue;
    }
    // The second graph does not have this weight param, so add it to the second graph
    if (!params_weights_set.count(fir.first)) {
      MS_LOG(DEBUG) << "Can't find " << fir.first << " in outer graph, add it";
      second_df_builder->add_parameter(fir.second);
      SetParamNodeMapInGraphInfoMap(second_df_builder, fir.first, fir.second);
      inputs->emplace_back(fir.second);
      weights_args->emplace_back(fir.second->default_param());
    } else {
      // Needs replacing
      MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
      auto it = std::find_if(second_graph_info->params.begin(), second_graph_info->params.end(),
                             [&fir](const std::pair<std::string, ParameterPtr> &sec) {
                               return sec.second->has_default() && fir.second->name() == sec.second->name();
                             });
      if (it != second_graph_info->params.end()) {
        manager->Replace(fir.second, it->second);
        inputs->emplace_back(it->second);
        weights_args->emplace_back(it->second->default_param());
      }
    }
  }
}
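// MakeNestedCnode supports high-order differentiation: the already-built first
// grad graph is differentiated again with ad::Grad (forward elimination is
// temporarily disabled), the result is attached to the current graph as a CNode,
// and KPynativeWithFProp registers it so the outer autodiff can consume it.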
void GradExecutor::MakeNestedCnode(const py::object &cell, const py::tuple &forward_args,
                                   const pipeline::ResourcePtr &resource, const py::object &out) {
  if (cell_stack_.empty()) {
    MS_LOG(DEBUG) << "No nested grad found";
    return;
  }
  FuncGraphPtr first_grad_fg = nullptr;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    first_grad_fg = curr_g();
    MS_LOG(DEBUG) << "Bprop nested";
  } else {
    first_grad_fg = resource->func_graph();
  }
  MS_EXCEPTION_IF_NULL(first_grad_fg);
  DumpGraphIR("first_grad_fg.ir", first_grad_fg);
  std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
  ValuePtrList weights_args;
  DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args);
  pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
  r->manager()->AddFuncGraph(first_grad_fg);
  set_eliminate_forward(false);
  first_grad_fg->transforms().erase(kGrad);
  FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r);
  set_eliminate_forward(true);
  DumpGraphIR("second_grad_fg.ir", second_grad_fg);
  r->Clean();
  MS_LOG(DEBUG) << "Get pre graph ptr " << curr_g().get();
  auto cnode = curr_g()->NewCNode(inputs);
  auto out_id = GetId(out);
  SetTupleArgsToGraphInfoMap(curr_g(), out, cnode);
  SetNodeMapInGraphInfoMap(curr_g(), out_id, cnode);
  MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();
  // Get input values
  ValuePtrList input_args;
  for (size_t i = 0; i < forward_args.size(); ++i) {
    const auto &arg = PyObjToValue(forward_args[i]);
    input_args.emplace_back(arg);
  }
  input_args.insert(input_args.end(), weights_args.begin(), weights_args.end());
  // Get output values
  py::object new_out;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !py::isinstance<py::tuple>(out)) {
    new_out = py::make_tuple(out);
  } else {
    new_out = out;
  }
  const auto &out_value = PyObjToValue(new_out);
  if (!top_cell()->k_pynative_cell_ptr()->KPynativeWithFProp(cnode, input_args, out_value, second_grad_fg)) {
    MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
  }
  need_renormalize_ = true;
}
void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) {
  MS_EXCEPTION_IF_NULL(top_cell);
  auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                           [&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); });
  if (iter == top_cell_list_.end()) {
    MS_LOG(WARNING) << "Can not find top cell " << top_cell.get() << " cell id " << top_cell->cell_id()
                    << " from top cell list";
  } else {
    (void)top_cell_list_.erase(iter);
  }
}
void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
                                       const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) {
  // Get actual output value and added output value.
  if (!py::isinstance<py::tuple>(out)) {
    MS_LOG(EXCEPTION) << "The output value of ms_function func graph should be a tuple.";
  }
  auto tuple_out = py::cast<py::tuple>(out);
  constexpr size_t tuple_out_size = 2;
  if (tuple_out.size() != tuple_out_size) {
    MS_LOG(EXCEPTION) << "The tuple size of output value of ms_function func graph should be 2.";
  }
  py::object actual_out = tuple_out[0];
  auto actual_out_v = PyObjToValue(actual_out);
  auto added_out = PyObjToValue(tuple_out[1]);
  MS_LOG(DEBUG) << "Added output value is: " << added_out->ToString();
  // Identify op info for the currently running ms_func graph.
  OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
  op_exec_info->op_name = phase;
  RecordGradOpInfo(op_exec_info, actual_out_v);
  MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;
  // Step 1: Update actual output tensors used in grad graph.
  MS_LOG(DEBUG) << "ms_function actual output value: " << actual_out_v->ToString();
  UpdateForwardTensorInfoInBpropGraph(op_exec_info, actual_out_v);
  // Step 2: Update output tensors of added forward nodes, which are added to the return node of the ms_function
  // func graph.
  if (top_cell()->op_info_with_ms_func_forward_tensors().count(op_exec_info->op_info)) {
    UpdateMsFunctionForwardTensors(op_exec_info, added_out);
    return;
  }
  MS_LOG(DEBUG) << "Ms func graph runs for the first time. The graph phase is: " << graph_phase();
  if (!need_construct_graph()) {
    MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
  }
  ReplaceNewTensorsInGradGraph(top_cell(), op_exec_info, added_out, ms_func_graph, grad_graph);
  // Clone new ms_function func graph and grad graph.
  auto new_ms_func_graph = BasicClone(ms_func_graph);
  auto new_grad_graph = BasicClone(grad_graph, true);
  auto new_make_tuple = new_ms_func_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(new_make_tuple);
  new_ms_func_graph->set_output(new_make_tuple->input(1));
  // Make Adjoint for grad graph
  MakeAdjointForMsFunction(new_ms_func_graph, new_grad_graph, actual_out, args, actual_out_v);
}
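// GradMsFunction is the ms_function entry point: the compiled func graph for the
// current phase returns the actual output plus added forward outputs; the actual
// output is handed back to Python, and when the grad flag is set the pair is
// forwarded to GradMsFunctionInner to wire the grad graph into the top cell.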
py::object GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
  // Get actual forward output object.
  if (graph_phase().empty()) {
    MS_LOG(EXCEPTION) << "The graph phase is empty, can not obtain ms_function func graph.";
  }
  const auto &phase = graph_phase();
  MS_LOG(DEBUG) << "ms_function func graph phase: " << phase;
  auto executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  FuncGraphPtr ms_func_graph = executor->GetFuncGraph(phase);
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  py::object ret = out;
  if (ms_func_graph->modify_output()) {
    auto tuple_out = py::cast<py::tuple>(out);
    ret = tuple_out[0];
  }
  // Make Adjoint for grad graph of ms_function.
  if (!grad_flag_) {
    MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
    set_graph_phase("");
    return ret;
  }
  FuncGraphPtr grad_graph = executor->GetGradGraph(phase);
  MS_EXCEPTION_IF_NULL(grad_graph);
  GradMsFunctionInner(phase, out, args, ms_func_graph, grad_graph);
  set_graph_phase("");
  return ret;
}
void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
  MS_LOG(DEBUG) << "Clear top cell grad resource " << GetCellId(cell, args);
  if (grad_order_ > 0) {
    --grad_order_;
  }
  check_graph_cell_id_.clear();
  grad_operation_.clear();
  forward()->node_abs_map().clear();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
void GradExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear grad res";
  grad_flag_ = false;
  enable_op_cache_ = true;
  grad_is_running_ = false;
  need_renormalize_ = false;
  eliminate_forward_ = true;
  custom_bprop_cell_count_ = 0;
  grad_order_ = 0;
  top_cell_switch_counts_ = 0;
  check_graph_cell_id_.clear();
  grad_operation_.clear();
  top_cell_ = nullptr;
  bprop_cell_list_.clear();
  already_run_top_cell_.clear();
  ClearCellRes();
  std::stack<std::pair<std::string, bool>>().swap(bprop_grad_stack_);
  std::stack<std::string>().swap(cell_stack_);
  std::stack<TopCellInfoPtr>().swap(high_order_stack_);
}
GradExecutorPtr PynativeExecutor::grad_executor() const {
  MS_EXCEPTION_IF_NULL(grad_executor_);
  return grad_executor_;
}
ForwardExecutorPtr PynativeExecutor::forward_executor() const {
  MS_EXCEPTION_IF_NULL(forward_executor_);
  return forward_executor_;
}
bool PynativeExecutor::grad_flag() const { return grad_executor()->grad_flag(); }
void PynativeExecutor::set_grad_flag(bool flag) { grad_executor()->set_grad_flag(flag); }
void PynativeExecutor::set_graph_phase(const std::string &graph_phase) {
  grad_executor()->set_graph_phase(graph_phase);
}
void PynativeExecutor::set_py_exe_path(const py::object &py_exe_path) {
  if (!py::isinstance<py::str>(py_exe_path)) {
    MS_LOG(EXCEPTION) << "Failed, py_exe_path input is not a str";
  }
  auto py_exe_path_s = py::cast<std::string>(py_exe_path);
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s);
}
void PynativeExecutor::set_kernel_build_server_dir(const py::object &kernel_build_server_dir) {
  if (!py::isinstance<py::str>(kernel_build_server_dir)) {
    MS_LOG(EXCEPTION) << "Failed, kernel_build_server_dir input is not a str";
  }
  auto kernel_build_server_dir_s = py::cast<std::string>(kernel_build_server_dir);
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, kernel_build_server_dir_s);
}
py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  return grad_executor()->CheckGraph(cell, args);
}
py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
                                             const py::args &args) {
  return grad_executor()->CheckAlreadyRun(grad, cell, args);
}
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args) {
  py::object ret;
  PynativeExecutorTry(grad_executor()->RunGraph, &ret, cell, args);
  return ret;
}
void PynativeExecutor::ClearCell(const std::string &cell_id) {
  MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id;
  grad_executor()->ClearCellRes(cell_id);
}
void PynativeExecutor::ClearGrad(const py::object &cell, const py::args &args) {
  MS_LOG(DEBUG) << "Clear grad";
  return grad_executor()->ClearGrad(cell, args);
}
void PynativeExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear all res";
  session::PynativeTaskManager::GetInstance().Reset();
  for (auto &item : kMindRtBackends) {
    MS_EXCEPTION_IF_NULL(item.second);
    item.second->ClearOpBuilderResource();
  }
  SetLazyBuild(false);
  cell_depth_ = 0;
  // May exit during the RunOp step
  auto ms_context = MsContext::GetInstance();
  if (ms_context != nullptr) {
    ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  }
  ConfigManager::GetInstance().ResetIterNum();
  if (forward_executor_ != nullptr) {
    forward_executor_->ClearRes();
  }
  if (grad_executor_ != nullptr) {
    grad_executor_->ClearRes();
  }
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
  kSessionBackends.clear();
  kMindRtBackends.clear();
  g_pyobj_id_cache.clear();
}
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  // Only build the graph for the new cell when the grad flag is set
  if (!grad_executor()->grad_flag()) {
    MS_LOG(DEBUG) << "Grad flag is false";
    return;
  }
  py::object ret;
  PynativeExecutorTry(grad_executor()->InitGraph, &ret, cell, args);
}
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  if (!grad_executor()->grad_flag()) {
    MS_LOG(DEBUG) << "Grad flag is false";
    return;
  }
  MS_LOG(DEBUG) << "Enter end graph process.";
  py::object ret;
  PynativeExecutorTry(grad_executor()->LinkGraph, &ret, cell, out, args);
  MS_LOG(DEBUG) << "Leave end graph process.";
}
py::object PynativeExecutor::GradMsFunction(const py::object &out, const py::args &args) {
  return grad_executor()->GradMsFunction(out, args);
}
void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::object &grad_position, const py::args &args) {
  py::object ret;
  PynativeExecutorTry(grad_executor()->GradGraph, &ret, grad, cell, weights, grad_position, args);
}
void PynativeExecutor::Sync() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ExecuteAllTask();
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    for (auto &item : kSessionBackends) {
      MS_EXCEPTION_IF_NULL(item.second);
      item.second->SyncStream();
    }
  } else {
    for (auto &item : kMindRtBackends) {
      MS_EXCEPTION_IF_NULL(item.second);
      item.second->SyncStream();
    }
    for (auto &item : kSessionBackends) {
      MS_EXCEPTION_IF_NULL(item.second);
      item.second->SyncStream();
    }
  }
}
void PynativeExecutor::SetLazyBuild(bool enable) { forward_executor()->set_lazy_build(enable); }
void PynativeExecutor::EnterCell() {
  if (cell_depth_ < UINT32_MAX) {
    ++cell_depth_;
  } else {
    MS_LOG(ERROR) << "Cell call stack too deep";
  }
}
void PynativeExecutor::ExitCell() {
  if (cell_depth_ > 0) {
    --cell_depth_;
  }
}
bool PynativeExecutor::IsTopCell() const { return cell_depth_ == 0; }
void PynativeExecutor::ExecuteAllTask() {
  session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
  for (auto &item : kMindRtBackends) {
    MS_EXCEPTION_IF_NULL(item.second);
    item.second->SyncLazyTasks();
  }
}
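// The pybind registration below exposes the executor to Python as
// PynativeExecutor_. A minimal usage sketch from the Python side, assuming the
// standard mindspore._c_expression import path (illustrative only, not part of
// this file):
//   from mindspore._c_expression import PynativeExecutor_
//   executor = PynativeExecutor_.get_instance()
//   executor.set_grad_flag(True)
//   executor.new_graph(cell, *inputs)           # before the cell's construct runs
//   executor.end_graph(cell, output, *inputs)   # after construct finishes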
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("enter_cell", &PynativeExecutor::EnterCell, "enter cell.")
                           .def("exit_cell", &PynativeExecutor::ExitCell, "exit cell.")
                           .def("is_top_cell", &PynativeExecutor::IsTopCell, "check top cell.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.")
                           .def("check_run", &PynativeExecutor::CheckAlreadyRun, "pynative check graph run before.")
                           .def("grad_ms_function", &PynativeExecutor::GradMsFunction, "pynative grad for ms_function.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear_cell", &PynativeExecutor::ClearCell, "pynative clear status.")
                           .def("clear_res", &PynativeExecutor::ClearRes, "pynative clear exception res.")
                           .def("clear_grad", &PynativeExecutor::ClearGrad, "pynative clear grad status.")
                           .def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
                           .def("set_lazy_build", &PynativeExecutor::SetLazyBuild, "pynative build kernel async")
                           .def("execute_all_task", &PynativeExecutor::ExecuteAllTask, "clear all task")
                           .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
                           .def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
                           .def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.")
                           .def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,
                                py::arg("py_exe_path") = py::str(""), "set python executable path.")
                           .def("set_kernel_build_server_dir", &PynativeExecutor::set_kernel_build_server_dir,
                                py::arg("kernel_build_server_dir") = py::str(""),
                                "set kernel build server directory path.");
                       }));
}  // namespace mindspore::pynative