You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

pynative_execute.cc 124 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/pynative/pynative_execute.h"
  17. #include <typeinfo>
  18. #include <map>
  19. #include <set>
  20. #include <memory>
  21. #include <sstream>
  22. #include <unordered_set>
  23. #include <algorithm>
  24. #include "debug/trace.h"
  25. #include "pybind_api/ir/tensor_py.h"
  26. #include "ir/param_info.h"
  27. #include "ir/anf.h"
  28. #include "ir/cell.h"
  29. #include "ir/tensor.h"
  30. #include "utils/any.h"
  31. #include "utils/utils.h"
  32. #include "utils/ms_context.h"
  33. #include "utils/check_convert_utils.h"
  34. #include "utils/context/context_extends.h"
  35. #include "utils/config_manager.h"
  36. #include "utils/convert_utils_py.h"
  37. #include "frontend/operator/ops.h"
  38. #include "frontend/operator/composite/do_signature.h"
  39. #include "pipeline/jit/parse/data_converter.h"
  40. #include "pipeline/jit/parse/resolve.h"
  41. #include "pipeline/jit/static_analysis/prim.h"
  42. #include "pipeline/jit/static_analysis/auto_monad.h"
  43. #include "backend/session/session_factory.h"
  44. #include "backend/optimizer/pass/const_input_to_attr_registry.h"
  45. #include "backend/optimizer/common/helper.h"
  46. #include "pipeline/jit/action.h"
  47. #include "pipeline/pynative/base.h"
  48. #include "pybind_api/api_register.h"
  49. #include "vm/transform.h"
  50. #include "frontend/optimizer/ad/grad.h"
  51. #include "pipeline/jit/resource.h"
  52. #include "pipeline/jit/pipeline.h"
  53. #include "pipeline/jit/pass.h"
  54. #include "frontend/parallel/context.h"
  55. #ifdef ENABLE_GE
  56. #include "pipeline/pynative/pynative_execute_ge.h"
  57. #endif
  58. #include "debug/anf_ir_dump.h"
using mindspore::tensor::TensorPy;
// Length of the "0x"-prefixed pointer string used when composing object ids.
const size_t PTR_LEN = 15;
// primitive unable to infer value for constant input in PyNative mode
static const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
                                                   "mixed_precision_cast"};
// Python module names resolved at runtime when looking up functional ops and dtype objects.
static const char kOpsFunctionModelName[] = "mindspore.ops.functional";
static const char kMSDtypeModelName[] = "mindspore.common.dtype";
  66. namespace mindspore::pynative {
// Lazily created backend session shared by all single-op executions in this file.
static std::shared_ptr<session::SessionBasic> session = nullptr;
// Definitions for PynativeExecutor's static members (singleton instance, its
// creation lock, and the monotonically increasing graph id counter).
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
int64_t PynativeExecutor::graph_id_ = 0;
// Invokes `method` on `executor` and translates any escaping C++ exception into
// a form the Python interpreter can consume. On every failure path the executor
// state is cleared first so a broken compilation does not poison the next call.
// pybind11 error types are rethrown as the same type to preserve the Python
// exception class seen by the caller.
template <typename... Args>
void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
  MS_EXCEPTION_IF_NULL(executor);
  try {
    (executor->*method)(args...);
  } catch (const py::error_already_set &ex) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
    // these info from screen, no need to open log file to find these info
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    PynativeExecutor::GetInstance()->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    PynativeExecutor::GetInstance()->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    PynativeExecutor::GetInstance()->ClearRes();
    // Last resort: report the (mangled) name of the active unknown exception type.
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
}
  107. inline ValuePtr PyAttrValue(const py::object &obj) {
  108. ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  109. if (!converted_ret) {
  110. MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
  111. }
  112. return converted_ret;
  113. }
// Computes a stable string id for a Python object, used as a cache/lookup key.
// Tensors use their intrinsic id; types, scalars and None get literal encodings;
// tuples/lists are encoded recursively from their elements' ids; anything else
// falls back to a Python-side id helper.
static std::string GetId(const py::object &obj) {
  if (py::isinstance<tensor::Tensor>(obj)) {
    auto tensor_ptr = py::cast<tensor::TensorPtr>(obj);
    return tensor_ptr->id();
  } else if (py::isinstance<mindspore::Type>(obj)) {
    auto type_ptr = py::cast<mindspore::TypePtr>(obj);
    return "type" + type_ptr->ToString();
  } else if (py::isinstance<py::str>(obj) || py::isinstance<py::int_>(obj) || py::isinstance<py::float_>(obj)) {
    return std::string(py::str(obj));
  } else if (py::isinstance<py::none>(obj)) {
    return "none";
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    auto p_list = py::cast<py::tuple>(obj);
    // NOTE(review): "tuple:" carries a trailing colon but "list" does not — the
    // prefixes only need to be distinct, but confirm before unifying, since these
    // strings participate in cached ids.
    string prefix = py::isinstance<py::tuple>(obj) ? "tuple:" : "list";
    if (p_list.empty()) {
      prefix = "empty";
    } else {
      std::string key;
      for (size_t i = 0; i < p_list.size(); ++i) {
        key += std::string(py::str(GetId(p_list[i]))) + ":";
      }
      prefix += key;
    }
    return prefix;
  }
  // Fallback: ask the Python parse module for an object id.
  py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  return py::cast<std::string>(ret);
}
  142. std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
  143. std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
  144. for (size_t i = 0; i < dtypes.size(); ++i) {
  145. auto it = type_indexes.find(dtypes[i]);
  146. if (it == type_indexes.end()) {
  147. (void)type_indexes.emplace(std::make_pair(dtypes[i], std::vector<size_t>{i}));
  148. } else {
  149. it->second.emplace_back(i);
  150. }
  151. }
  152. return type_indexes;
  153. }
  154. TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
  155. if (max_type == TypeId::kNumberTypeBool) {
  156. if (has_scalar_int64) {
  157. max_type = TypeId::kNumberTypeInt64;
  158. }
  159. if (has_scalar_float32) {
  160. max_type = TypeId::kNumberTypeFloat32;
  161. }
  162. }
  163. if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
  164. max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
  165. max_type = TypeId::kNumberTypeFloat32;
  166. }
  167. if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
  168. max_type = TypeId::kNumberTypeInt16;
  169. }
  170. return max_type;
  171. }
// For each signature dtype group with at least two members, scans the grouped
// Python arguments and picks the destination TypeId that all of them should be
// implicitly converted to: the highest-priority tensor dtype, adjusted for
// scalar floats/ints and the uint8+int8 case via JudgeMaxType.
std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
                                                const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
  std::map<SignatureEnumDType, TypeId> dst_type;
  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
    auto type = it->first;
    auto indexes = it->second;
    // Groups with a default/empty dtype or a single member need no conversion.
    if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < 2) {
      continue;
    }
    size_t priority = 0;
    TypeId max_type = TypeId::kTypeUnknown;
    bool has_scalar_float32 = false;
    bool has_scalar_int64 = false;
    bool has_tensor_int8 = false;
    for (size_t index : indexes) {
      auto obj = py_args[index];
      if (py::isinstance<py::float_>(obj)) {
        has_scalar_float32 = true;
      }
      // Exclude bools explicitly: a Python bool also satisfies py::int_.
      if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) {
        has_scalar_int64 = true;
      }
      if (py::isinstance<tensor::Tensor>(obj)) {
        auto arg = py::cast<tensor::TensorPtr>(obj);
        TypeId arg_type_id = arg->data_type();
        auto type_priority = prim::type_map.find(arg_type_id);
        if (type_priority == prim::type_map.end()) {
          continue;
        }
        if (arg_type_id == kNumberTypeInt8) {
          has_tensor_int8 = true;
        }
        // Track the tensor dtype with the highest promotion priority.
        if (type_priority->second > priority) {
          max_type = type_priority->first;
          priority = type_priority->second;
        }
      }
    }
    max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8);
    (void)dst_type.emplace(std::make_pair(type, max_type));
  }
  return dst_type;
}
  215. std::string TypeIdToMsTypeStr(const TypeId &type_id) {
  216. auto type_name = type_name_map.find(type_id);
  217. if (type_name == type_name_map.end()) {
  218. MS_LOG(EXCEPTION) << "For implicit type conversion, not support convert to the type: " << TypeIdToType(type_id);
  219. }
  220. return type_name->second;
  221. }
  222. bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
  223. MS_EXCEPTION_IF_NULL(dtypes);
  224. auto signature = prim->signatures();
  225. bool has_sig_dtype = false;
  226. (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
  227. [&has_sig_dtype](const Signature &sig) {
  228. auto dtype = sig.dtype;
  229. if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
  230. has_sig_dtype = true;
  231. }
  232. return dtype;
  233. });
  234. return has_sig_dtype;
  235. }
// Runs single-primitive abstract inference for `prim` on `args_spec_list` and
// stores the inferred abstract on op_exec_info. Attr recording is enabled
// around the evaluation so attrs added during infer are captured by the prim.
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
                   const abstract::AbstractBasePtrList &args_spec_list) {
  MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
  prim->BeginRecordAddAttr();
  AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  prim->EndRecordAddAttr();
  op_exec_info->abstract = infer_res;
  MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}
  245. std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
  246. const std::vector<tensor::TensorPtr> &input_tensors) {
  247. MS_EXCEPTION_IF_NULL(op_exec_info);
  248. std::string graph_info;
  249. // get input tensor info
  250. for (const auto &tensor : input_tensors) {
  251. MS_EXCEPTION_IF_NULL(tensor);
  252. auto tensor_shape = tensor->shape();
  253. (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
  254. [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
  255. (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
  256. if (tensor->device_address() != nullptr) {
  257. (void)graph_info.append(
  258. std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id()) + "_");
  259. (void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format() +
  260. "_");
  261. }
  262. }
  263. // get prim and abstract info
  264. (void)graph_info.append(op_exec_info->op_name + "_");
  265. // get attr info
  266. const auto &op_prim = op_exec_info->py_primitive;
  267. MS_EXCEPTION_IF_NULL(op_prim);
  268. const auto &attr_map = op_prim->attrs();
  269. (void)std::for_each(attr_map.begin(), attr_map.end(),
  270. [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
  271. // Add output information(shape, type id) of the operator to graph_info to solve the problem of cache missing
  272. // caused by operators like DropoutGenMask whose output is related to values of input when input shapes are
  273. // the same but values are different
  274. auto abstr = op_exec_info->abstract;
  275. MS_EXCEPTION_IF_NULL(abstr);
  276. auto build_shape = abstr->BuildShape();
  277. MS_EXCEPTION_IF_NULL(build_shape);
  278. (void)graph_info.append(build_shape->ToString() + "_");
  279. auto build_type = abstr->BuildType();
  280. MS_EXCEPTION_IF_NULL(build_type);
  281. (void)graph_info.append(std::to_string(build_type->type_id()) + "_");
  282. return graph_info;
  283. }
  284. bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
  285. const std::unordered_set<size_t> &input_attrs) {
  286. MS_EXCEPTION_IF_NULL(op_prim);
  287. auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  288. if (input_names_value == nullptr) {
  289. return false;
  290. }
  291. auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  292. if (input_index >= input_names_vec.size()) {
  293. MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!";
  294. }
  295. if (input_attrs.find(input_index) != input_attrs.end()) {
  296. ValuePtr value = parse::data_converter::PyDataToValue(input_object);
  297. MS_EXCEPTION_IF_NULL(value);
  298. auto input_name = input_names_vec[input_index];
  299. op_prim->AddAttr(input_name, value);
  300. return true;
  301. }
  302. return false;
  303. }
  304. void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
  305. std::vector<tensor::TensorPtr> *input_tensors) {
  306. MS_EXCEPTION_IF_NULL(op_prim);
  307. MS_EXCEPTION_IF_NULL(input_tensors);
  308. for (const auto &input_object : tuple_inputs) {
  309. if (!py::isinstance<tensor::Tensor>(input_object)) {
  310. MS_LOG(EXCEPTION) << "The input object is not a tensor!";
  311. }
  312. auto tensor = py::cast<tensor::TensorPtr>(input_object);
  313. MS_EXCEPTION_IF_NULL(tensor);
  314. input_tensors->emplace_back(tensor);
  315. }
  316. op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(tuple_inputs.size())}));
  317. }
  318. void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  319. MS_EXCEPTION_IF_NULL(input_tensors);
  320. ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  321. MS_EXCEPTION_IF_NULL(input_value);
  322. if (!input_value->isa<ValueTuple>()) {
  323. MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  324. }
  325. auto value_tuple = input_value->cast<ValueTuplePtr>();
  326. MS_EXCEPTION_IF_NULL(value_tuple);
  327. tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  328. MS_EXCEPTION_IF_NULL(tensor_ptr);
  329. input_tensors->emplace_back(tensor_ptr);
  330. }
  331. void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
  332. std::vector<tensor::TensorPtr> *input_tensors, int64_t *tensor_mask) {
  333. MS_EXCEPTION_IF_NULL(op_prim);
  334. MS_EXCEPTION_IF_NULL(input_tensors);
  335. MS_EXCEPTION_IF_NULL(tensor_mask);
  336. if (!py::isinstance<py::tuple>(input_object)) {
  337. MS_LOG(EXCEPTION) << "The input should be a tuple!";
  338. }
  339. auto tuple_inputs = py::cast<py::tuple>(input_object);
  340. if (tuple_inputs.empty()) {
  341. MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  342. }
  343. auto inputs = py::cast<py::tuple>(input_object);
  344. if (py::isinstance<tensor::Tensor>(inputs[0])) {
  345. PlantTensorTupleToVector(inputs, op_prim, input_tensors);
  346. } else {
  347. ConvertValueTupleToTensor(input_object, input_tensors);
  348. *tensor_mask = kValueNodeTensorMask;
  349. }
  350. }
// Converts one Python op input into tensor form and appends it to
// *input_tensors. Scalars become value-node tensors (and set *tensor_mask);
// lists are normalized to tuples and handled recursively; None contributes
// nothing. The order of the isinstance checks below is significant.
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                             std::vector<tensor::TensorPtr> *input_tensors, int64_t *tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (py::isinstance<tensor::Tensor>(input_object)) {
    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  } else if (py::isinstance<py::float_>(input_object)) {
    // Python float scalar -> float32 value-node tensor.
    double input_value = py::cast<py::float_>(input_object);
    tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::int_>(input_object)) {
    // Python int scalar -> int64 value-node tensor.
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::array>(input_object)) {
    tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
  } else if (py::isinstance<py::list>(input_object)) {
    // Normalize a list to a tuple before the multi-object conversion.
    auto list_inputs = py::cast<py::list>(input_object);
    py::tuple tuple_inputs(list_inputs.size());
    for (size_t i = 0; i < tuple_inputs.size(); ++i) {
      tuple_inputs[i] = list_inputs[i];
    }
    ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    // None contributes no tensor.
    return;
  } else {
    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}
// Prepares the tensor inputs and per-tensor masks for a single-op run:
// registered const inputs may be folded into primitive attributes; everything
// else is converted to tensors via ConvertPyObjectToTensor. Some device/op
// combinations disable the const-to-attr folding (see the special cases below).
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  // Dynamic-shape ops skip const-to-attr folding unless explicitly allowed.
  if (op_run_info->is_dynamic_shape &&
      dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
    MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name;
    reg_exist = false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // EmbeddingLookup only folds const inputs to attrs on CPU.
  if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
    if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
      reg_exist = false;
    }
  }
  if (op_run_info->op_name == prim::kPrimGatherD->name()) {
    // Gather op needs converting const input to attr on GPU device
    if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
      reg_exist = false;
    }
  }
  op_prim->BeginRecordAddAttr();
  size_t input_num = op_run_info->op_inputs.size();
  for (size_t index = 0; index < input_num; ++index) {
    // convert const input to attr
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // convert const and tuple input to tensor
    int64_t tensor_mask = static_cast<int64_t>(op_run_info->inputs_mask[index]);
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // mark tensors, data : 0, weight : 1, valuenode: 2
    // One input may expand into several tensors (tuple inputs); mask each of them.
    std::vector<int64_t> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
  op_prim->EndRecordAddAttr();
}
  431. void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
  432. MS_EXCEPTION_IF_NULL(op_run_info);
  433. PrimitivePtr op_prim = op_run_info->py_primitive;
  434. MS_EXCEPTION_IF_NULL(op_prim);
  435. std::string op_name = op_run_info->op_name;
  436. auto attrs = op_prim->attrs();
  437. for (auto attr : attrs) {
  438. bool converted = CheckAndConvertUtils::ConvertAttrValueToString(op_name, attr.first, &attr.second);
  439. if (converted) {
  440. op_prim->set_attr(attr.first, attr.second);
  441. }
  442. bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(op_name, attr.first, &attr.second);
  443. if (converted_ir_attr) {
  444. op_prim->set_attr(attr.first, attr.second);
  445. }
  446. }
  447. }
  448. BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  449. if (utils::isa<VectorRef>(base_ref)) {
  450. auto ref_list = utils::cast<VectorRef>(base_ref);
  451. py::tuple output_tensors(ref_list.size());
  452. for (size_t i = 0; i < ref_list.size(); ++i) {
  453. auto output = TransformBaseRefListToTuple(ref_list[i]);
  454. if (utils::isa<tensor::TensorPtr>(output)) {
  455. auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
  456. MS_EXCEPTION_IF_NULL(tensor_ptr);
  457. output_tensors[i] = tensor_ptr;
  458. } else if (utils::isa<PyObjectRef>(output)) {
  459. py::object obj = utils::cast<PyObjectRef>(output).object_;
  460. py::tuple tensor_tuple = py::cast<py::tuple>(obj);
  461. output_tensors[i] = tensor_tuple;
  462. } else {
  463. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  464. }
  465. }
  466. return std::make_shared<PyObjectRef>(output_tensors);
  467. } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
  468. return base_ref;
  469. } else {
  470. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  471. }
  472. }
  473. size_t GetTupleSize(const py::tuple &args) {
  474. size_t count = 0;
  475. for (size_t i = 0; i < args.size(); i++) {
  476. if (py::isinstance<py::tuple>(args[i])) {
  477. count += GetTupleSize(args[i]);
  478. } else {
  479. count += 1;
  480. }
  481. }
  482. return count;
  483. }
  484. void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) {
  485. for (size_t i = 0; i < arg.size(); i++) {
  486. if (py::isinstance<py::tuple>(arg[i])) {
  487. ConvertTupleArg(res, index, arg[i]);
  488. } else {
  489. (*res)[(*index)++] = arg[i];
  490. }
  491. }
  492. }
  493. py::tuple ConvertArgs(const py::tuple &args) {
  494. size_t tuple_size = GetTupleSize(args);
  495. py::tuple res(tuple_size);
  496. size_t index = 0;
  497. for (size_t i = 0; i < args.size(); i++) {
  498. if (py::isinstance<py::tuple>(args[i])) {
  499. ConvertTupleArg(&res, &index, args[i]);
  500. } else {
  501. res[index++] = args[i];
  502. }
  503. }
  504. return res;
  505. }
  506. void ClearPyNativeSession() { session = nullptr; }
// Destructor: clear all cached executor state (graphs, maps, stacks) on teardown.
PynativeExecutor::~PynativeExecutor() {
  MS_LOG(DEBUG) << "PynativeExecutor destructor";
  ClearRes();
}
  511. void CheckPyNativeContext() {
  512. auto parallel_context = parallel::ParallelContext::GetInstance();
  513. MS_EXCEPTION_IF_NULL(parallel_context);
  514. auto ms_context = MsContext::GetInstance();
  515. MS_EXCEPTION_IF_NULL(ms_context);
  516. auto parallel_mode = parallel_context->parallel_mode();
  517. if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL &&
  518. ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
  519. MS_LOG(EXCEPTION) << "PyNative Only support STAND_ALONE and DATA_PARALLEL, but got:" << parallel_mode;
  520. }
  521. }
// Entry point for running a single primitive op from Python in PyNative mode.
// Builds the OpExecInfo from the (prim, name, inputs) args and executes it,
// clearing executor state before any exception propagates back to Python.
py::object RunOp(const py::args &args) {
  CheckPyNativeContext();
  auto executor = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(DEBUG) << "RunOp name: " << op_exec_info->op_name << " start, args: " << args.size();
  try {
    return executor->RunOpInner(op_exec_info);
  } catch (const py::error_already_set &ex) {
    executor->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    // pybind-specific exceptions are re-thrown by copy so Python sees the
    // original error type; state is cleared first in every branch.
    executor->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    executor->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    executor->ClearRes();
    throw py::index_error(ex);
  } catch (const std::exception &ex) {
    executor->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    // Unknown exception type: report its mangled name via the MindSpore log.
    executor->ClearRes();
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
}
// Run one op in PyNative mode: build the grad-graph CNode (if grad is on),
// infer/cache the output abstract, execute on the selected backend, and
// record the outputs needed later for grad-graph construction.
py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  // MixedPrecisionCast runs directly; it needs no grad bookkeeping here.
  if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
    py::tuple ret = RunOpWithInitBackendPolicy(op_exec_info);
    if (ret.size() == 1) {
      return ret[0];
    }
    return std::move(ret);
  }
  // make cnode for building grad graph if grad flag is set.
  abstract::AbstractBasePtrList args_spec_list;
  std::vector<bool> op_masks;
  auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list);
  op_exec_info->inputs_mask = op_masks;
  // get output abstract info
  bool is_find = false;
  GetOpOutputAbstract(op_exec_info, args_spec_list, &is_find);
  MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString();
  // infer output value for const prim
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
  if (!output["value"].is_none()) {
    // Infer already produced a constant value: no backend execution needed.
    return output["value"];
  }
  if (prim->is_const_prim()) {
    return py::cast("");
  }
  // add output abstract info into cache
  if (!is_find && !op_exec_info->is_dynamic_shape) {
    // const_value need infer every step
    auto &out = prim_abs_list_[prim->id()];
    out[args_spec_list].abs = op_exec_info->abstract;
    out[args_spec_list].attrs = prim->evaluate_added_attrs();
    MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
  }
  // run op with selected backend
  auto result = RunOpWithInitBackendPolicy(op_exec_info);
  py::object out_real;
  // Unwrap a single output unless the abstract says the op returns a sequence.
  if (result.size() == 1 && op_exec_info->abstract != nullptr &&
      !op_exec_info->abstract->isa<abstract::AbstractSequeue>()) {
    out_real = result[0];
  } else {
    out_real = result;
  }
  // update output abstract for cnode
  if (cnode != nullptr) {
    cnode->set_abstract(op_exec_info->abstract);
  }
  std::string obj_id = GetId(out_real);
  node_abs_map_[obj_id] = op_exec_info->abstract;
  // save info for building grad graph
  SaveOutputNodeMap(obj_id, out_real, cnode);
  SaveAllResult(op_exec_info, cnode, out_real);
  // Update the abstract and device address of value node with tensor in grad graph
  UpdateAbstractAndDeviceAddress(op_exec_info, out_real);
  return out_real;
}
// Build an OpExecInfo from the python (prim, name, inputs) triple.
// Returns nullptr (after logging an error) when the arg count is wrong.
OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(ERROR) << "Three args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
  auto op_name = py::cast<std::string>(args[PY_NAME]);
  op_exec_info->op_name = op_name;
  if (grad_flag()) {
    // op_index uniquely identifies this op call within the current run:
    // "<op_name>_<per-op running counter>". The index string is also appended
    // to the current cell's op-info record, when a cell is active.
    op_exec_info->op_index = op_name + "_" + std::to_string(op_index_map_[op_name]);
    if (!cell_op_info_stack_.empty()) {
      std::string &cell_op_info = cell_op_info_stack_.top();
      cell_op_info += op_exec_info->op_index;
    }
    op_index_map_[op_name]++;
  }
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  MS_EXCEPTION_IF_NULL(prim);
  if (!prim->HasPyObj()) {
    MS_LOG(EXCEPTION) << "Pyobj is empty";
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  op_exec_info->op_inputs = args[PY_INPUTS];
  return op_exec_info;
}
  638. AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
  639. abstract::AbstractBasePtrList *args_spec_list) {
  640. MS_EXCEPTION_IF_NULL(op_masks);
  641. MS_EXCEPTION_IF_NULL(args_spec_list);
  642. MS_EXCEPTION_IF_NULL(op_exec_info);
  643. auto prim = op_exec_info->py_primitive;
  644. std::vector<AnfNodePtr> inputs;
  645. inputs.emplace_back(NewValueNode(prim));
  646. const auto &signature = prim->signatures();
  647. auto sig_size = signature.size();
  648. auto size = op_exec_info->op_inputs.size();
  649. // ignore monad signature
  650. for (auto sig : signature) {
  651. if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
  652. --sig_size;
  653. }
  654. }
  655. if (sig_size > 0 && sig_size != size) {
  656. MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
  657. << "inputs size " << sig_size;
  658. }
  659. if (op_exec_info->op_name != prim::kPrimCast->name()) {
  660. RunParameterAutoMixPrecisionCast(op_exec_info);
  661. }
  662. MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
  663. for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
  664. abstract::AbstractBasePtr abs = nullptr;
  665. const auto &obj = op_exec_info->op_inputs[i];
  666. auto id = GetId(obj);
  667. auto it = node_abs_map_.find(id);
  668. if (it != node_abs_map_.end()) {
  669. abs = it->second;
  670. }
  671. bool op_mask = false;
  672. if (py::isinstance<tensor::MetaTensor>(obj)) {
  673. auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
  674. if (meta_tensor) {
  675. op_mask = meta_tensor->is_parameter();
  676. }
  677. }
  678. MS_LOG(DEBUG) << "Gen args i " << i << " op_mask " << op_mask;
  679. (*op_masks).emplace_back(op_mask);
  680. if (need_construct_graph()) {
  681. AnfNodePtr input_node = nullptr;
  682. if (!graph_info_map_.empty() && !top_cell_list_.empty()) {
  683. input_node = GetInput(obj, op_mask);
  684. }
  685. // update abstract
  686. if (input_node != nullptr) {
  687. if (input_node->abstract() != nullptr) {
  688. abs = input_node->abstract();
  689. }
  690. inputs.emplace_back(input_node);
  691. }
  692. }
  693. auto const_input_index = prim->get_const_input_indexes();
  694. bool have_const_input = !const_input_index.empty();
  695. bool is_const_prim = prim->is_const_prim();
  696. MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
  697. << prim->is_const_prim();
  698. bool is_const_input =
  699. have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
  700. if (abs == nullptr || is_const_prim || is_const_input) {
  701. MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
  702. ValuePtr input_value = PyAttrValue(obj);
  703. abs = input_value->ToAbstract();
  704. if (!is_const_prim && !is_const_input) {
  705. auto config = abstract::AbstractBase::kBroadenTensorOnly;
  706. abs = abs->Broaden(config);
  707. MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
  708. }
  709. node_abs_map_[id] = abs;
  710. }
  711. (*args_spec_list).emplace_back(abs);
  712. }
  713. CNodePtr cnode = nullptr;
  714. if (need_construct_graph()) {
  715. MS_EXCEPTION_IF_NULL(curr_g_);
  716. cnode = curr_g_->NewCNodeInOrder(inputs);
  717. MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << " new cnode is " << cnode->DebugString(4);
  718. }
  719. return cnode;
  720. }
  721. void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
  722. const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
  723. MS_EXCEPTION_IF_NULL(is_find);
  724. MS_EXCEPTION_IF_NULL(op_exec_info);
  725. *is_find = false;
  726. auto op_name = op_exec_info->op_name;
  727. auto prim = op_exec_info->py_primitive;
  728. MS_EXCEPTION_IF_NULL(prim);
  729. if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
  730. auto abs_list = prim_abs_list_[prim->id()];
  731. MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
  732. if (abs_list.find(args_spec_list) != abs_list.end()) {
  733. MS_LOG(DEBUG) << "Match prim ok " << op_name;
  734. op_exec_info->abstract = abs_list[args_spec_list].abs;
  735. prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
  736. *is_find = true;
  737. }
  738. }
  739. if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
  740. // use python infer method
  741. if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
  742. PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
  743. }
  744. }
  745. // get output dynamic shape info
  746. auto py_abstract = op_exec_info->abstract;
  747. MS_EXCEPTION_IF_NULL(py_abstract);
  748. auto py_shape = py_abstract->BuildShape();
  749. MS_EXCEPTION_IF_NULL(py_shape);
  750. auto py_shape_info = py_shape->ToString();
  751. if (py_shape_info.find("-1") != string::npos) {
  752. auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
  753. MS_EXCEPTION_IF_NULL(c_abstract);
  754. auto c_shape = c_abstract->BuildShape();
  755. MS_EXCEPTION_IF_NULL(c_shape);
  756. auto c_shape_info = c_shape->ToString();
  757. MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info;
  758. if (c_shape_info.find("-1") != string::npos) {
  759. op_exec_info->is_dynamic_shape = true;
  760. }
  761. }
  762. }
  763. py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
  764. size_t index) {
  765. py::tuple cast_args(3);
  766. cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  767. cast_args[PY_NAME] = prim::kPrimCast->name();
  768. std::string dst_type_str = TypeIdToMsTypeStr(type_id);
  769. py::object dst_type = parse::python_adapter::GetPyFn(kMSDtypeModelName, dst_type_str);
  770. py::tuple inputs(2);
  771. inputs[0] = arg;
  772. inputs[1] = dst_type;
  773. cast_args[PY_INPUTS] = inputs;
  774. auto op_exec = GenerateOpExecInfo(cast_args);
  775. op_exec->is_mixed_precision_cast = true;
  776. op_exec->next_op_name = op_name;
  777. op_exec->next_input_index = index;
  778. return RunOpInner(op_exec);
  779. }
  780. py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name,
  781. size_t index) {
  782. MS_EXCEPTION_IF_NULL(is_cast);
  783. auto tensor = py::cast<tensor::TensorPtr>(obj);
  784. auto cast_type = tensor->cast_dtype();
  785. py::object cast_output = obj;
  786. if (cast_type != nullptr) {
  787. auto source_element = tensor->Dtype();
  788. if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
  789. MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
  790. *is_cast = true;
  791. return DoAutoCast(obj, cast_type->type_id(), op_name, index);
  792. }
  793. }
  794. return cast_output;
  795. }
  796. py::object PynativeExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple,
  797. const std::string &op_name, size_t index) {
  798. MS_EXCEPTION_IF_NULL(is_cast);
  799. auto tuple_size = static_cast<int64_t>(tuple.size());
  800. py::tuple result(tuple_size);
  801. for (int64_t i = 0; i < tuple_size; i++) {
  802. if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
  803. MS_LOG(DEBUG) << "Call cast for item " << i;
  804. result[i] = DoParamMixPrecisionCast(is_cast, tuple[i], op_name, index);
  805. } else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
  806. result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i], op_name, index);
  807. } else {
  808. result[i] = tuple[i];
  809. }
  810. }
  811. return std::move(result);
  812. }
// Apply implicit dtype casts so each op input matches the destination type
// resolved for its signature dtype group. kRWWrite (Parameter) inputs must
// already match — a mismatch raises instead of casting.
// NOTE(review): dtypes[i] and signature[i] are indexed in lockstep with
// op_inputs when non-empty — assumes their sizes match the input count;
// confirm callers guarantee this.
void PynativeExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
                                       const std::vector<SignatureEnumDType> &dtypes,
                                       const OpExecInfoPtr &op_exec_info) {
  const auto &signature = prim->signatures();
  auto &out_args = op_exec_info->op_inputs;
  for (size_t i = 0; i < out_args.size(); ++i) {
    // No need to implicit cast if no dtype.
    if (dtypes.empty() || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
      continue;
    }
    // Skip groups with no resolved destination type.
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end() || it->second == kTypeUnknown) {
      continue;
    }
    MS_LOG(DEBUG) << "Check inputs " << i;
    auto obj = out_args[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      sig = signature[i].rw;
    }
    bool is_parameter = false;
    TypeId arg_type_id = kTypeUnknown;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
      if (arg->is_parameter()) {
        is_parameter = true;
        MS_LOG(DEBUG) << "Parameter is read " << i;
      }
      arg_type_id = arg->data_type();
    }
    // implicit cast: a type outside the implicit-conversion map is treated
    // as "same type" (no cast attempted).
    bool is_same_type = false;
    if (arg_type_id != kTypeUnknown) {
      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
    }
    if (sig == SignatureEnumRW::kRWWrite) {
      // Write inputs must be Parameters and must already have the target dtype.
      if (!is_parameter) {
        prim::RaiseExceptionForCheckParameter(prim->name(), i, "not");
      }
      if (arg_type_id != kTypeUnknown) {
        if (!is_same_type) {
          prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
                                                 TypeIdToMsTypeStr(it->second));
        }
      }
    }
    if (is_same_type) {
      continue;
    }
    // Only tensors and python scalars support implicit conversion.
    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
                              << "th input is a not support implicit conversion type: "
                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
                              << py::cast<py::str>(obj) << ".";
    }
    py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i);
    out_args[i] = cast_output;
  }
}
// Run the two-stage input casting for an op: (1) per-input mix-precision
// casts driven by each tensor's cast_dtype, (2) signature-driven implicit
// dtype unification across input groups.
// NOTE(review): signature[i].rw assumes the signature list (when non-empty)
// is at least as long as op_inputs — confirm with the op registrations.
void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
  size_t size = op_exec_info->op_inputs.size();
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &signature = prim->signatures();
  for (size_t i = 0; i < size; i++) {
    auto obj = op_exec_info->op_inputs[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      sig = signature[i].rw;
    }
    MS_LOG(DEBUG) << "Check mix precision " << op_exec_info->op_name << " input " << i << " "
                  << std::string(py::repr(obj));
    // mix precision for non param
    bool is_cast = false;
    py::object cast_output;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor && meta_tensor->is_parameter()) {
        // Parameters are only cast when the signature reads them (kRWRead).
        if (sig != SignatureEnumRW::kRWRead) {
          continue;
        }
      }
      // redundant cast call if the tensor is a const Tensor.
      cast_output = DoParamMixPrecisionCast(&is_cast, obj, prim->name(), i);
    } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
      // mix precision for tuple inputs
      cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj, prim->name(), i);
    }
    if (is_cast) {
      op_exec_info->op_inputs[i] = cast_output;
    }
  }
  std::vector<SignatureEnumDType> dtypes;
  bool has_dtype_sig = GetSignatureType(prim, &dtypes);
  std::map<SignatureEnumDType, TypeId> dst_types;
  if (has_dtype_sig) {
    // fetch info for implicit cast: resolve one destination type per dtype group.
    auto type_indexes = GetTypeIndex(dtypes);
    dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
  }
  MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name;
  DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
}
// Map a python input object to an AnfNode in the current graph.
// Parameters (op_mask == true) become — or reuse — free parameters on the
// top cell's df_builder; tuples/lists become MakeTuple CNodes (tensor-first
// tuples only); everything else becomes a ValueNode.
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
  AnfNodePtr node = nullptr;
  std::string obj_id = GetId(obj);
  if (op_mask) {
    MS_LOG(DEBUG) << "Cell parameters(weights)";
    // get the parameter name from parameter object
    auto name_attr = parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    auto param_name = py::cast<std::string>(name_attr);
    auto df_builder = GetDfbuilder(top_cell_id_);
    MS_EXCEPTION_IF_NULL(df_builder);
    auto graph_info = graph_info_map_.at(df_builder);
    MS_EXCEPTION_IF_NULL(graph_info);
    if (graph_info->params.find(obj_id) == graph_info->params.end()) {
      // First use of this Parameter: add it as a free parameter holding the
      // tensor as its default value, and register it in both the df_builder
      // and the current graph's maps.
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      free_param->debug_info()->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(obj);
      free_param->set_default_param(value);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      SetParamNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetParamNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
      SetNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
      return free_param;
    }
    // Parameter already registered: reuse its node.
    node = graph_info->node_map.at(obj_id).first;
    MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
    return node;
  }
  auto graph_info = graph_info_map_.at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  if (graph_info->node_map.find(obj_id) != graph_info->node_map.end()) {
    // op(x, y)
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj, obj_id);
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    auto tuple = obj.cast<py::tuple>();
    // cell((1,2)): support not mix (scalar, tensor)
    if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
      return MakeValueNode(obj, obj_id);
    }
    // Tensor tuple: build a MakeTuple CNode from each element's node.
    std::vector<AnfNodePtr> args;
    args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple_size = tuple.size();
    for (size_t i = 0; i < tuple_size; i++) {
      args.emplace_back(GetInput(tuple[i], false));
    }
    auto cnode = curr_g_->NewCNode(args);
    SetNodeMapInGraphInfoMap(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    node = MakeValueNode(obj, obj_id);
  }
  node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
                  : MS_LOG(DEBUG) << "Get input node " << node->ToString() << " " << obj_id;
  return node;
}
// After running an op in grad mode, sync its freshly produced output tensors
// into the tensors embedded in grad-graph value nodes, matching by the tensor
// ids recorded for this op_index on the first step. Off CPU the device address
// is re-pointed; on CPU the data is memcpy'd into the old buffer.
void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (!grad_flag()) {
    return;
  }
  auto op_index = op_exec_info->op_index;
  auto output_value = PyAttrValue(out_real);
  MS_EXCEPTION_IF_NULL(output_value);
  std::vector<tensor::TensorPtr> output_tensors;
  TensorValueToTensor(output_value, &output_tensors);
  if (cell_op_index_with_tensor_id_[top_cell_id_].find(op_index) == cell_op_index_with_tensor_id_[top_cell_id_].end()) {
    // first step: only record which tensor ids this op produced.
    std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) {
      cell_op_index_with_tensor_id_[top_cell_id_][op_index].emplace_back(tensor->id());
    });
    return;
  }
  auto ms_context = MsContext::GetInstance();
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const auto &tensor_id_list = cell_op_index_with_tensor_id_[top_cell_id_][op_index];
  // NOTE(review): output_tensors[i] assumes the op produces the same number of
  // tensors every step as it did on the first step — confirm this invariant.
  for (size_t i = 0; i < tensor_id_list.size(); ++i) {
    auto tensor_id = tensor_id_list[i];
    if (cell_tensor_id_with_tensor_[top_cell_id_].find(tensor_id) != cell_tensor_id_with_tensor_[top_cell_id_].end()) {
      auto &new_tensor = output_tensors[i];
      auto &tensors_in_value_node = cell_tensor_id_with_tensor_[top_cell_id_][tensor_id];
      std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
        MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id "
                      << tensor->id() << ", device address " << tensor->device_address().get()
                      << " with New tensor obj " << new_tensor.get() << ", tensor id " << new_tensor->id()
                      << ", device address " << new_tensor->device_address().get();
        tensor->set_shape(new_tensor->shape());
        tensor->set_data_type(new_tensor->data_type());
        if (target != kCPUDevice) {
          // Non-CPU: just swap in the new device address.
          tensor->set_device_address(new_tensor->device_address());
        } else {
          // CPU: copy the new data into the existing buffer so other holders
          // of the old address keep seeing consistent data.
          auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
          auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
          auto old_ptr = old_device_address->GetMutablePtr();
          auto new_ptr = new_device_address->GetPtr();
          MS_EXCEPTION_IF_NULL(old_ptr);
          MS_EXCEPTION_IF_NULL(new_ptr);
          auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
          if (ret != EOK) {
            MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
          }
        }
      });
    }
  }
}
  1030. void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
  1031. MS_EXCEPTION_IF_NULL(resource);
  1032. std::set<std::string> forward_op_tensor_id;
  1033. for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) {
  1034. const auto &tensor_id_list = elem.second;
  1035. for (const auto &tensor_id : tensor_id_list) {
  1036. forward_op_tensor_id.emplace(tensor_id);
  1037. }
  1038. }
  1039. cell_tensor_id_with_tensor_[top_cell_id_].clear();
  1040. const auto &func_graph = resource->func_graph();
  1041. const auto &value_node_list = func_graph->value_nodes();
  1042. for (const auto &elem : value_node_list) {
  1043. auto value_node = elem.first->cast<ValueNodePtr>();
  1044. MS_EXCEPTION_IF_NULL(value_node);
  1045. std::vector<tensor::TensorPtr> tensors;
  1046. TensorValueToTensor(value_node->value(), &tensors);
  1047. for (const auto &tensor : tensors) {
  1048. if (tensor->device_address() != nullptr &&
  1049. forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) {
  1050. cell_tensor_id_with_tensor_[top_cell_id_][tensor->id()].emplace_back(tensor);
  1051. MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id()
  1052. << ", device address " << tensor->device_address().get();
  1053. }
  1054. }
  1055. }
  1056. }
// Release device memory held by value-node tensors from the previous step
// (skipped entirely on CPU). For dynamic cells, forward-op tensors cached in
// all_value_node_tensors_ get their device memory explicitly freed; in all
// cases the recorded value-node tensors drop their device address so it can
// be re-assigned on the next run.
void PynativeExecutor::CleanPreMemoryInValueNode() {
  auto ms_context = MsContext::GetInstance();
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target == "CPU") {
    return;
  }
  if (has_dynamic_cell_) {
    // Collect ids of tensors produced by forward ops in the current top cell.
    std::set<std::string> forward_op_tensor_id;
    for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) {
      const auto &tensor_id_list = elem.second;
      for (const auto &tensor_id : tensor_id_list) {
        forward_op_tensor_id.emplace(tensor_id);
      }
    }
    for (auto &tensor : all_value_node_tensors_) {
      if (tensor->device_address() != nullptr &&
          forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) {
        // Free the device memory eagerly, then detach the address.
        tensor->device_address()->ClearDeviceMemory();
        tensor->set_device_address(nullptr);
      }
    }
    all_value_node_tensors_.clear();
  }
  // Detach device addresses from all recorded value-node tensors.
  const auto &tensor_id_with_tensor = cell_tensor_id_with_tensor_[top_cell_id_];
  for (const auto &elem : tensor_id_with_tensor) {
    const auto &tensors_in_value_node = elem.second;
    for (const auto &tensor : tensors_in_value_node) {
      MS_EXCEPTION_IF_NULL(tensor);
      tensor->set_device_address(nullptr);
    }
  }
}
// Return the AnfNode recorded for obj in the current graph. If the object is
// an element of an op output (index list non-trivial), build a chain of
// TupleGetItem CNodes to extract it, propagating forward values and abstracts
// along the chain.
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
  auto graph_info = graph_info_map_.at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  auto &out = graph_info->node_map.at(obj_id);
  // Index list {-1} means the node itself (not an element of a tuple output).
  if (out.second.size() == 1 && out.second[0] == -1) {
    return out.first;
  }
  MS_LOG(DEBUG) << "Output size " << out.second.size();
  // Params node
  if (graph_info->params.find(obj_id) != graph_info->params.end()) {
    auto para_node = out.first;
    for (auto &idx : out.second) {
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node,
                                                    NewValueNode(idx)};
      para_node = curr_g_->NewCNode(tuple_get_item_inputs);
    }
    return para_node;
  }
  // Normal node: prefer the recorded forward value; fall back to the python object.
  auto node = out.first->cast<CNodePtr>();
  auto abs = node->abstract();
  ValuePtr out_obj = nullptr;
  if (node->forward().first != nullptr) {
    out_obj = node->forward().first;
  } else {
    out_obj = PyAttrValue(obj);
  }
  for (auto &idx : out.second) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
    node = curr_g_->NewCNode(tuple_get_item_inputs);
    if (out_obj->isa<ValueTuple>()) {
      // Narrow the forward value to the selected element and attach it to the
      // new TupleGetItem node.
      node->add_input_value(out_obj, "");
      node->add_input_value(MakeValue(idx), "");
      out_obj = (*out_obj->cast<ValueTuplePtr>())[idx];
      node->set_forward(out_obj, "");
    }
    if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
      // Propagate the element abstract down the getitem chain.
      auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
      MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
      node->set_abstract(prim_abs);
    }
  }
  if (node->abstract() != nullptr) {
    node_abs_map_[obj_id] = node->abstract();
  }
  MS_LOG(DEBUG) << "GetObjNode output " << node->DebugString(6);
  return node;
}
  1137. AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
  1138. ValuePtr converted_ret = nullptr;
  1139. parse::ConvertData(obj, &converted_ret);
  1140. auto node = NewValueNode(converted_ret);
  1141. SetNodeMapInGraphInfoMap(curr_g_, obj_id, node);
  1142. return node;
  1143. }
// Record the mapping from an op's real Python output object(s) to the cnode
// that produced them, in the current graph's info map. For a tuple output with
// more than one element, each element id is additionally mapped to
// (cnode, index) so a TupleGetItem can be constructed later for that element.
void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real,
                                         const AnfNodePtr &cnode) {
  // Nothing to record when no grad graph is being constructed.
  if (!need_construct_graph()) {
    MS_LOG(DEBUG) << "No need save output";
    return;
  }
  MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString(4) << " id " << obj_id;
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    auto size = static_cast<int64_t>(value.size());
    // Single-element tuples are handled by the whole-object mapping below only.
    if (size > 1) {
      for (int64_t i = 0; i < size; ++i) {
        auto value_id = GetId(value[i]);
        SetNodeMapInGraphInfoMap(curr_g_, value_id, cnode, i);
      }
    }
  }
  SetNodeMapInGraphInfoMap(curr_g_, obj_id, cnode);
  SetPyObjInGraphInfoMap(curr_g_, obj_id);
}
  1164. void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node,
  1165. const py::object &out_real) {
  1166. if (!grad_flag() || node == nullptr) {
  1167. return;
  1168. }
  1169. MS_EXCEPTION_IF_NULL(op_exec_info);
  1170. auto cnode = node->cast<CNodePtr>();
  1171. MS_EXCEPTION_IF_NULL(cnode);
  1172. // save input object
  1173. size_t size = op_exec_info->op_inputs.size();
  1174. for (size_t i = 0; i < size; i++) {
  1175. auto obj = op_exec_info->op_inputs[i];
  1176. auto obj_id = GetId(obj);
  1177. if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) {
  1178. cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]);
  1179. } else {
  1180. cnode->add_input_value(nullptr, "");
  1181. }
  1182. }
  1183. // save output object
  1184. auto output_value = PyAttrValue(out_real);
  1185. MS_EXCEPTION_IF_NULL(output_value);
  1186. cnode->set_forward(output_value, op_exec_info->op_index);
  1187. auto out_id = GetId(out_real);
  1188. if (py::isinstance<py::tuple>(out_real)) {
  1189. auto tuple_item = py::cast<py::tuple>(out_real);
  1190. for (size_t i = 0; i < tuple_item.size(); i++) {
  1191. auto tuple_item_id = GetId(tuple_item[i]);
  1192. obj_to_forward_id_[tuple_item_id] = op_exec_info->op_index;
  1193. }
  1194. }
  1195. obj_to_forward_id_[out_id] = op_exec_info->op_index;
  1196. }
  1197. void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) {
  1198. if (t_map == nullptr) {
  1199. return;
  1200. }
  1201. for (size_t i = 0; i < tuple->size(); i++) {
  1202. ValuePtr tuple_i = (*tuple)[i];
  1203. if (tuple_i->isa<tensor::Tensor>()) {
  1204. auto t = tuple_i->cast<tensor::TensorPtr>();
  1205. (*t_map)[t->id()] = t;
  1206. } else if (tuple_i->isa<ValueTuple>()) {
  1207. GenTupleMap(tuple_i->cast<ValueTuplePtr>(), t_map);
  1208. }
  1209. }
  1210. MS_LOG(DEBUG) << "End GenTupleMap " << tuple->ToString();
  1211. }
  1212. ValuePtr PynativeExecutor::CleanTupleAddr(const ValueTuplePtr &tuple) {
  1213. std::vector<ValuePtr> value_list;
  1214. for (size_t i = 0; i < tuple->size(); i++) {
  1215. ValuePtr tuple_i = (*tuple)[i];
  1216. if (tuple_i->isa<tensor::Tensor>()) {
  1217. auto t = tuple_i->cast<tensor::TensorPtr>();
  1218. auto new_tensor = std::make_shared<tensor::Tensor>(*t);
  1219. new_tensor->set_device_address(nullptr);
  1220. value_list.emplace_back(new_tensor);
  1221. } else if (tuple_i->isa<ValueTuple>()) {
  1222. value_list.emplace_back(CleanTupleAddr(tuple_i->cast<ValueTuplePtr>()));
  1223. } else {
  1224. MS_LOG(DEBUG) << "Tuple[i] value " << tuple_i->ToString();
  1225. value_list.emplace_back(tuple_i);
  1226. }
  1227. }
  1228. MS_LOG(DEBUG) << "End CleanTupleAddr";
  1229. return std::make_shared<ValueTuple>(value_list);
  1230. }
// Pick a backend policy via InitEnv, execute the op with it, and raise on any
// non-success status. Returns the op's output tuple.
py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) {
  auto backend_policy = InitEnv(op_exec_info);
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // returns a null py::tuple on error
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(EXCEPTION) << "Failed to run " << op_exec_info->op_name;
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}
// Prepare the execution environment for a single op and decide which backend
// policy to use. Without GE: open TSD if needed, then choose Ms-prior or
// VM-only based on the context's backend policy. With GE: initialize GE and use
// it exclusively. Ops listed in vm_operators always fall back to VM-only.
MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!context::IsTsdOpened(ms_context)) {
    if (!context::OpenTsd(ms_context)) {
      MS_LOG(EXCEPTION) << "Open tsd failed";
    }
  }
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  context::PynativeInitGe(ms_context);
  backend_policy = kMsBackendGeOnly;
#endif
  // Some ops are only implemented in the VM; force VM-only for them.
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  return backend_policy;
}
// Dispatch op execution to the backend selected by `backend_policy`, writing
// the resulting status into *status. GE-prior falls back to the VM on failure;
// Ms-prior only logs an error on failure (no fallback).
py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                                    PynativeStatusCode *const status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // use vm only
      MS_LOG(INFO) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // use GE first, use vm when GE fails
      MS_LOG(INFO) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // use Ms first,use others when ms failed
      MS_LOG(INFO) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!!!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}
// Execute an op in the Python VM. For the special gradient-related ops
// (HookBackward / InsertGradientOf / stop_gradient) the inputs are passed
// through as outputs, with tensors re-wrapped so forward cnode outputs get a
// fresh Tensor object sharing the same data. Other ops are run through the
// primitive's registered Python compute function. The result is always
// returned as a py::tuple.
py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  auto &op_inputs = op_exec_info->op_inputs;
  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf" ||
      op_exec_info->op_name == "stop_gradient") {
    // These ops are identity-like in the forward pass: outputs mirror inputs.
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      auto input_obj_id = GetId(input);
      auto tensor = py::cast<tensor::TensorPtr>(input);
      MS_EXCEPTION_IF_NULL(tensor);
      if (obj_to_forward_id_.find(input_obj_id) == obj_to_forward_id_.end() &&
          op_exec_info->op_name == "HookBackward") {
        // the input object is not a output of forward cnode, eg: parameter
        result[i] = tensor;
      } else {
        // the input object is a output of forward cnode
        // Wrap in a new Tensor object sharing the same data pointer, device
        // address, and sync status -- the output gets a distinct object id.
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
        new_tensor->set_device_address(tensor->device_address());
        new_tensor->set_sync_status(tensor->sync_status());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto primitive = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(primitive);
  auto result = primitive->RunPyComputeFunction(op_inputs);
  MS_LOG(INFO) << "RunOpInVM end";
  if (py::isinstance<py::none>(result)) {
    MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  *status = PYNATIVE_SUCCESS;
  if (py::isinstance<py::tuple>(result)) {
    return result;
  }
  // Normalize scalar results to a one-element tuple.
  py::tuple tuple_result = py::make_tuple(result);
  return std::move(tuple_result);
}
  1353. py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  1354. MS_EXCEPTION_IF_NULL(op_exec_info);
  1355. MS_EXCEPTION_IF_NULL(status);
  1356. MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
  1357. auto ms_context = MsContext::GetInstance();
  1358. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
  1359. if (session == nullptr) {
  1360. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  1361. session = session::SessionFactory::Get().Create(device_target);
  1362. MS_EXCEPTION_IF_NULL(session);
  1363. session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  1364. }
  1365. std::vector<tensor::TensorPtr> input_tensors;
  1366. std::vector<int64_t> tensors_mask;
  1367. ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  1368. ConvertAttrToUnifyMindIR(op_exec_info);
  1369. // get graph info for checking it whether existing in the cache
  1370. std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
  1371. #if defined(__APPLE__)
  1372. session::OpRunInfo op_run_info = {op_exec_info->op_name,
  1373. op_exec_info->py_primitive,
  1374. op_exec_info->abstract,
  1375. op_exec_info->is_dynamic_shape,
  1376. op_exec_info->is_mixed_precision_cast,
  1377. op_exec_info->next_op_name,
  1378. static_cast<int>(op_exec_info->next_input_index)};
  1379. #else
  1380. session::OpRunInfo op_run_info = {op_exec_info->op_name,
  1381. op_exec_info->py_primitive,
  1382. op_exec_info->abstract,
  1383. op_exec_info->is_dynamic_shape,
  1384. op_exec_info->is_mixed_precision_cast,
  1385. op_exec_info->next_op_name,
  1386. op_exec_info->next_input_index};
  1387. #endif
  1388. VectorRef outputs;
  1389. session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
  1390. if (op_exec_info->is_dynamic_shape) {
  1391. op_exec_info->abstract = op_run_info.abstract;
  1392. }
  1393. auto result = BaseRefToPyData(outputs);
  1394. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  1395. *status = PYNATIVE_SUCCESS;
  1396. MS_LOG(INFO) << "End run op [" << op_exec_info->op_name << "] with backend policy ms";
  1397. return result;
  1398. }
// Push the current graph onto the graph stack (entering a nested cell scope).
void PynativeExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); }
  1400. void PynativeExecutor::PushCurrentCellOpInfoToStack() {
  1401. std::string cell_op_info = "Cell ops: ";
  1402. cell_op_info_stack_.push(cell_op_info);
  1403. }
  1404. void PynativeExecutor::PopGraphStack() {
  1405. if (graph_stack_.empty()) {
  1406. MS_LOG(EXCEPTION) << "Stack graph_stack_ is empty";
  1407. }
  1408. graph_stack_.pop();
  1409. if (!graph_stack_.empty()) {
  1410. curr_g_ = graph_stack_.top();
  1411. }
  1412. }
  1413. void PynativeExecutor::PopCurrentCellOpInfoFromStack() {
  1414. if (cell_op_info_stack_.empty()) {
  1415. MS_LOG(EXCEPTION) << "The cell op info stack is empty";
  1416. }
  1417. cell_op_info_stack_.pop();
  1418. }
// Build a unique id for (cell, args): the cell's object id concatenated with
// the shape and type string of every argument. Argument abstracts are cached
// in node_abs_map_ on first sight. The raw id is finally normalized through
// GetTensorCellId to canonicalize "NoShape" entries.
std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
  auto cell_id = GetId(cell);
  for (size_t i = 0; i < args.size(); i++) {
    std::string arg_id = GetId(args[i]);
    auto it = node_abs_map_.find(arg_id);
    if (it != node_abs_map_.end()) {
      // Cached abstract: reuse its shape/type strings.
      cell_id += "_" + it->second->BuildShape()->ToString();
      cell_id += it->second->BuildType()->ToString();
    } else {
      // First time this argument is seen: derive and broaden its abstract,
      // cache it, then append the same shape/type strings.
      auto abs = PyAttrValue(args[i])->ToAbstract();
      auto config = abstract::AbstractBase::kBroadenTensorOnly;
      abs = abs->Broaden(config);
      node_abs_map_[arg_id] = abs;
      cell_id += "_" + abs->BuildShape()->ToString();
      cell_id += abs->BuildType()->ToString();
    }
  }
  return GetTensorCellId(cell_id);
}
// Canonicalize a cell id that contains "NoShape" segments: if a previously
// recorded cell id (same object-pointer prefix, containing "Tensor") matches
// this one segment-by-segment -- treating "NoShape" segments as wildcards --
// return the recorded id instead, so both map to the same cell graph.
std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) {
  // Fast path: nothing to canonicalize.
  if (cell_id.find("NoShape") == std::string::npos) {
    return cell_id;
  }
  // Key on the object-pointer prefix of the id.
  std::string key = cell_id.substr(0, PTR_LEN);
  // Split an id on '_' into segments; each segment keeps its trailing '_'
  // except the last.
  auto fn = [](const std::string &str, std::vector<std::string> &value) {
    size_t pos = 0;
    size_t pre_pos = 0;
    while ((pos = str.find_first_of('_', pre_pos)) != std::string::npos) {
      value.emplace_back(str.substr(pre_pos, pos - pre_pos + 1));
      pre_pos = pos + 1;
    }
    value.emplace_back(str.substr(pre_pos));
  };
  // Look for an already-recorded cell id with the same prefix that carries
  // concrete Tensor shape info.
  auto it =
    std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&key](const CellInfoPtr &value) {
      return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos;
    });
  if (it != cell_graph_list_.end()) {
    std::vector<std::string> pre_cell_id;
    std::vector<std::string> cur_cell_id;
    fn((*it)->cell_id, pre_cell_id);
    fn(cell_id, cur_cell_id);
    auto pre_tensor_size = pre_cell_id.size();
    if (pre_tensor_size == cur_cell_id.size()) {
      // Count segments that either match exactly or are "NoShape" wildcards.
      size_t same_tensor_count = 0;
      for (size_t i = 0; i < pre_tensor_size; ++i) {
        if (cur_cell_id[i].find("NoShape") != std::string::npos || cur_cell_id[i] == pre_cell_id[i]) {
          ++same_tensor_count;
        }
      }
      if (same_tensor_count == pre_tensor_size) {
        MS_LOG(DEBUG) << "Changed cell id from " << cell_id << " to " << (*it)->cell_id;
        return (*it)->cell_id;
      }
    }
  }
  return cell_id;
}
  1477. void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
  1478. if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
  1479. DumpIR(filename, graph);
  1480. }
  1481. }
  1482. bool PynativeExecutor::IsNestedGrad() const {
  1483. MS_LOG(DEBUG) << "Grad nested order is " << grad_order_;
  1484. return grad_order_ > 1;
  1485. }
  1486. bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
  1487. return std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
  1488. [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
  1489. }
  1490. bool PynativeExecutor::IsTopestGraph(const std::string &cell_id) {
  1491. return std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
  1492. [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; });
  1493. }
// Find the topest top cell whose id equals `cell_id`. If none matches and
// find_nearest is set, walk cell_graph_list_ in order and return the first
// topest top cell found among those cells, stopping the search once the
// requested cell id is reached. Returns nullptr when nothing matches.
TopCellInfoPtr PynativeExecutor::GetTopCell(const string &cell_id, bool find_nearest) {
  auto find_top_cell = [&](const string &cell_id) -> TopCellInfoPtr {
    auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &top_cell) {
      return cell_id == top_cell->cell_id && top_cell->is_topest;
    });
    if (iter != top_cell_list_.end()) {
      return *iter;
    }
    return nullptr;
  };
  TopCellInfoPtr top_cell = find_top_cell(cell_id);
  // find nearest top cell
  if (top_cell == nullptr && find_nearest) {
    for (const auto &cell_info : cell_graph_list_) {
      MS_EXCEPTION_IF_NULL(cell_info);
      top_cell = find_top_cell(cell_info->cell_id);
      // Stop at the requested cell; cells after it are not "nearer".
      if (cell_id == cell_info->cell_id) {
        break;
      }
    }
  }
  return top_cell;
}
// After a grad step, update the matching top cell's bookkeeping: record
// whether the VM compiled it, reset the forward-already-run flag, and mark it
// as needing grad. If it is the topest cell, the whole grad process is done,
// so clear the in-progress flag and reset the top cell index.
void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled) {
  auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
  if (it != top_cell_list_.end()) {
    (*it)->do_vm_compiled = vm_compiled;
    (*it)->forward_already_run = false;
    (*it)->need_grad = true;
    if ((*it)->is_topest) {
      in_grad_process_ = false;
      top_cell_index_ = 0;
    }
  }
}
// True when `cell_id` belongs to a custom-bprop cell: some recorded cell
// (from the current top cell onward) has a non-empty bprop_cell_id that is a
// substring of this id.
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) {
  return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
                     [&cell_id](const CellInfoPtr &value) {
                       return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos;
                     });
}
// First grad step <=> no recorded cell graph with this id is already marked as grad.
bool PynativeExecutor::IsFirstGradStep(const std::string &cell_id) { return !CheckCellGraph(cell_id, true); }
  1537. void PynativeExecutor::SubNestedGradOrder() {
  1538. if (grad_order_ > 0) {
  1539. --grad_order_;
  1540. }
  1541. }
  1542. bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) {
  1543. return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
  1544. [&cell_id, is_grad](const CellInfoPtr &value) {
  1545. return value->cell_id == cell_id && (!is_grad || value->is_grad);
  1546. });
  1547. }
  1548. bool PynativeExecutor::CheckDynamicCell(const std::string &cell_id) {
  1549. return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
  1550. [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_dynamic; });
  1551. }
  1552. bool PynativeExecutor::CheckRealDynamicCell(const std::string &cell_id) {
  1553. return std::any_of(
  1554. cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
  1555. [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_real_dynamic; });
  1556. }
  1557. void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
  1558. // Abnormal case
  1559. if (top_cell_list_.empty() && !graph_stack_.empty()) {
  1560. graph_id_ = 0;
  1561. graph_info_map_.clear();
  1562. cell_graph_list_.clear();
  1563. std::stack<FuncGraphPtr>().swap(graph_stack_);
  1564. }
  1565. if (CheckRealDynamicCell(cell_id)) {
  1566. if (IsTopGraph(cell_id) && graph_stack_.empty() && !IsBpropGraph(cell_id)) {
  1567. // Clear previous step resource
  1568. auto resource = GetResource(cell_id);
  1569. if (resource != nullptr && resource->results().find(pipeline::kBackend) != resource->results().end()) {
  1570. compile::BackendPtr backend = resource->results()[pipeline::kBackend].cast<compile::BackendPtr>();
  1571. auto ms_backend = std::dynamic_pointer_cast<compile::MsBackend>(backend);
  1572. ms_backend->ClearSessionGraphs();
  1573. }
  1574. }
  1575. }
  1576. }
  1577. FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
  1578. // If top graph hold
  1579. for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) {
  1580. if (cell_id.find((*it)->cell_id) != std::string::npos) {
  1581. return (*it)->df_builder;
  1582. }
  1583. }
  1584. // Current cell is not top graph, get first top cell
  1585. if (!top_cell_list_.empty()) {
  1586. return top_cell_list_.front()->df_builder;
  1587. }
  1588. return nullptr;
  1589. }
  1590. ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) {
  1591. for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) {
  1592. if (cell_id.find((*it)->cell_id) != std::string::npos) {
  1593. return (*it)->resource;
  1594. }
  1595. }
  1596. // Current cell is not top graph, get first top cell
  1597. if (!top_cell_list_.empty()) {
  1598. return top_cell_list_.front()->resource;
  1599. }
  1600. return nullptr;
  1601. }
// Return the AST node's name when its main type matches `type`; return an
// empty string for a Python None node or a main-type mismatch (logged as an
// error in the latter case).
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                            parse::AstMainType type) {
  MS_EXCEPTION_IF_NULL(ast);
  if (py::isinstance<py::none>(node)) {
    MS_LOG(DEBUG) << "Get none type node!";
    return "";
  }
  auto node_type = ast->GetNodeType(node);
  MS_EXCEPTION_IF_NULL(node_type);
  // check node type
  parse::AstMainType node_main_type = node_type->main_type();
  if (node_main_type != type) {
    MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
    return "";
  }
  std::string node_name = node_type->node_name();
  MS_LOG(DEBUG) << "Ast node is " << node_name;
  return node_name;
}
  1621. void PynativeExecutor::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
  1622. MS_EXCEPTION_IF_NULL(ast);
  1623. py::list args = ast->GetArgs(fn_node);
  1624. for (size_t i = 1; i < args.size(); i++) {
  1625. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  1626. MS_LOG(DEBUG) << "Input arg name: " << arg_name;
  1627. cell_input_args_.emplace(arg_name);
  1628. }
  1629. }
// Inspect the test expression of an if/while statement and decide whether it
// can make the cell dynamic. Returns true when the condition compares two
// attributes (e.g. `while self.a > self.b`), when either comparison side is
// not in the known-unchanged primitive set, or when the condition is a bare
// name that is one of the cell's input arguments (e.g. `if flag:`).
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse if/while expr";
  py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
  const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
  if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
    py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
    py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
    if (comparators_node.empty()) {
      MS_LOG(DEBUG) << "Get comparators node failed!";
      return false;
    }
    auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
    auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
    // while self.a > self.b and changed self.a or self.b
    if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
      // Build "<object id><attribute>" strings for both sides (e.g. "selfa")
      // and check the loop body for augmented assignments to either one.
      auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
      std::string left_variable;
      if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
        left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
      }
      auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
      std::string right_variable;
      if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
        right_variable =
          py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
      }
      return ParseBodyContext(ast, node, {left_variable, right_variable});
    }
    // if a[0]
    if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
      // For a subscript, classify by the subscripted value instead.
      py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
      left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
    }
    MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
    // Anything outside the known-unchanged set may vary between steps.
    if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
        unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
      return true;
    }
  }
  // if flag:
  if (node_name == parse::NAMED_PRIMITIVE_NAME) {
    std::string id = py::cast<std::string>(test_node.attr("id"));
    // Conditions on cell inputs can change per call -> dynamic.
    if (cell_input_args_.find(id) != cell_input_args_.end()) {
      return true;
    }
  }
  return false;
}
// Inspect an assignment statement and decide whether it can make the cell
// dynamic. Returns true for the pattern where the assigned value is a call on
// a subscripted object whose subscripted value is one of the cell's input
// arguments (e.g. `x = ops[i](inputs)` with `ops` being a cell input).
bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse assign expr";
  py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
  const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
  if (node_name == parse::NAMED_PRIMITIVE_CALL) {
    py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
    const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
    if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
      py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
      py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
      const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
      if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end()) {
        return true;
      }
    }
  }
  return false;
}
// Inspect an augmented assignment (e.g. `self.a += x`) and decide whether it
// targets one of the "<object id><attribute>" strings in `compare_prim`
// (collected from a surrounding while-condition). Returns true on a match,
// meaning the loop condition's operands are mutated inside the loop.
bool PynativeExecutor::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                              const std::vector<std::string> &compare_prim) {
  MS_LOG(DEBUG) << "Parse augassign expr";
  bool ret = false;
  if (compare_prim.empty()) {
    return ret;
  }
  py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
  if (py::isinstance<py::none>(target_node)) {
    MS_LOG(DEBUG) << "Parse target node is none!";
    return ret;
  }
  py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
  if (py::isinstance<py::none>(value_node)) {
    MS_LOG(DEBUG) << "Parse value node is none!";
    return ret;
  }
  // Rebuild the same "<object id><attribute>" form used by the caller.
  std::string assign_prim;
  if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
    assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
  }
  auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
  if (iter != compare_prim.end()) {
    ret = true;
  }
  return ret;
}
// Inspect a for-loop body and decide whether it can make the cell dynamic:
// true when any statement in the body is an assignment that
// ParseAssignExprNode classifies as dynamic.
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse for expr";
  py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
  if (py::isinstance<py::none>(body_node)) {
    MS_LOG(DEBUG) << "Parse body of for expression is none!";
    return false;
  }
  py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
  size_t count = LongToSize(pcount);
  MS_LOG(DEBUG) << "The for nodes count in body is " << count;
  for (size_t i = 0; i < count; ++i) {
    auto it = py::cast<py::list>(body_node)[i];
    const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
    if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
      return true;
    }
  }
  return false;
}
// Walk the statements of a function (or loop) body and decide whether any of
// them makes the cell dynamic, dispatching on statement kind: assignments,
// augmented assignments (checked against `compare_prim`), for-loops, and
// if/while statements. Stops at the first dynamic statement found.
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
                                        const std::vector<std::string> &compare_prim) {
  MS_EXCEPTION_IF_NULL(ast);
  py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
  if (py::isinstance<py::none>(func_obj)) {
    MS_LOG(DEBUG) << "Parse body of cell is none!";
    return false;
  }
  py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
  size_t count = IntToSize(pcount);
  MS_LOG(DEBUG) << "The nodes count in body is " << count;
  bool ret = false;
  for (size_t i = 0; i < count; ++i) {
    auto node = py::cast<py::list>(func_obj)[i];
    const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
    if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
      ret = ParseAssignExprNode(ast, node);
    } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
      ret = ParseAugAssignExprNode(ast, node, compare_prim);
    } else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
      ret = ParseForExprNode(ast, node);
    } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
      ret = ParseIfWhileExprNode(ast, node);
    }
    if (ret) {
      MS_LOG(INFO) << "Current cell is dynamic!";
      break;
    }
  }
  return ret;
}
  1773. std::string PynativeExecutor::GetCellInfo(const py::object &cell) {
  1774. if (py::isinstance<Cell>(cell)) {
  1775. auto c_cell = py::cast<CellPtr>(cell);
  1776. MS_EXCEPTION_IF_NULL(c_cell);
  1777. auto cell_info = c_cell->ToString();
  1778. return cell_info;
  1779. }
  1780. return "";
  1781. }
// Decide whether a cell's construct may change between calls (a "dynamic"
// cell) by parsing its Python source to an AST and scanning the body for
// input-dependent control flow. Cells on the ignore list, and cells whose
// source cannot be parsed, are treated as non-dynamic.
bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
  std::string cell_info = GetCellInfo(cell);
  if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
    return false;
  }
  // using ast parse to check whether the construct of cell will be changed
  auto ast = std::make_shared<parse::ParseAst>(cell);
  bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
  if (!success) {
    MS_LOG(ERROR) << "Parse code to ast tree failed";
    return false;
  }
  py::object fn_node = ast->GetAstNode();
  // get the name of input args as the initialize of dynamic_variables
  ParseInputArgs(ast, fn_node);
  // parse body context
  bool ret = false;
  ret = ParseBodyContext(ast, fn_node);
  // cell_input_args_ is scratch state for this analysis only; reset it.
  cell_input_args_.clear();
  return ret;
}
// Enter a cell in grad-construction mode. If the cell was already compiled
// (top graph, not dynamic), only refresh per-step state and push the op-info
// stack. Otherwise create a fresh FuncGraph for the cell, register it (as a
// new top graph when entered at the outermost level and not a bprop cell),
// map every argument to a graph parameter, and run the dynamic-cell analysis.
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
  auto cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
  // check whether cell needed to construct grad graph
  if (graph_stack_.empty() && !top_cell_list_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) {
    // Clear previous step resource
    auto init_fn = [&](bool flag) {
      CleanPreMemoryInValueNode();
      op_index_map_.clear();
      in_grad_process_ = true;
      in_bprop_process_ = false;
      // `flag` selects whether GetTopCell may fall back to the nearest top cell.
      auto top_cell = GetTopCell(cell_id, flag);
      MS_EXCEPTION_IF_NULL(top_cell);
      top_cell_id_ = top_cell->cell_id;
      top_cell_index_ = top_cell->top_cell_index;
      top_cell->forward_already_run = true;
      MS_LOG(DEBUG) << "Top cell id " << top_cell_id_;
    };
    // Outermost entry of a known topest graph: exact top cell lookup.
    if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) {
      init_fn(false);
    }
    // Not yet in a grad process: allow nearest-top-cell fallback.
    if (!in_grad_process_ && cell_op_info_stack_.empty()) {
      init_fn(true);
    }
    PushCurrentCellOpInfoToStack();
    MS_LOG(INFO) << "NewGraph already compiled";
    return;
  }
  // Init resource for constructing forward graph and grad graph
  curr_g_ = std::make_shared<FuncGraph>();
  ClearResidualRes(cell_id);
  if (graph_stack_.empty()) {
    if (IsBpropGraph(cell_id)) {
      in_bprop_process_ = true;
    } else {
      MakeNewTopGraph(cell_id, args);
    }
  }
  PushCurrentGraphToStack();
  PushCurrentCellOpInfoToStack();
  if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) {
    auto graph_info = std::make_shared<GraphInfo>(cell_id);
    graph_info_map_[curr_g_] = graph_info;
  }
  // Map each python argument to a FuncGraph parameter, including tuple
  // element mappings for tuple-valued arguments.
  for (size_t i = 0; i < args.size(); ++i) {
    auto param = args[i];
    auto new_param = curr_g_->add_parameter();
    std::string param_id = GetId(param);
    SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
    SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
    SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
  }
  // Check whether the construct of cell will be changed
  if (!has_dynamic_cell_) {
    has_dynamic_cell_ = IsDynamicCell(cell);
    MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_;
    if (has_dynamic_cell_ && IsBpropGraph(cell_id)) {
      // A dynamic bprop cell marks the top cell and everything after it dynamic.
      auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
                             [this](const CellInfoPtr &value) { return value->cell_id == top_cell_id_; });
      while (it != cell_graph_list_.end()) {
        (*it)->is_dynamic = true;
        ++it;
      }
    }
  }
}
// Create and register a new top-level cell entry before its forward graph is
// recorded: validates inputs, clears stale per-step resources, then allocates
// the df_builder graph and pipeline resource the grad graph will be built on.
void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args) {
  // Network inputs must be plain tensors, not Parameter objects.
  for (const auto &arg : args) {
    if (py::isinstance<tensor::Tensor>(arg)) {
      auto tensor = arg.cast<tensor::TensorPtr>();
      if (tensor && tensor->is_parameter()) {
        MS_EXCEPTION(TypeError) << "The inputs could not be Parameter.";
      }
    }
  }
  // Clear resource in old top cell
  if (CheckRealDynamicCell(cell_id)) {
    VectorClear<std::vector<TopCellInfoPtr>>(&top_cell_list_, cell_id);
  }
  CleanPreMemoryInValueNode();
  // Init resource for new top cell
  if (!CheckCellGraph(cell_id)) {
    has_dynamic_cell_ = false;
  }
  op_index_map_.clear();
  top_cell_id_ = cell_id;
  in_grad_process_ = true;
  // update forward already run flag with previous top cell
  auto pre_top_cell = GetTopCell(cell_id);
  if (pre_top_cell != nullptr) {
    pre_top_cell->forward_already_run = true;
  }
  auto df_builder = std::make_shared<FuncGraph>();
  auto graph_info = std::make_shared<GraphInfo>(cell_id);
  graph_info_map_[df_builder] = graph_info;
  auto resource = std::make_shared<pipeline::Resource>();
  resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
  auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
  top_cell_info->forward_already_run = true;
  // First time this id becomes a top cell: its cell graphs start at the end
  // of cell_graph_list_. Otherwise reuse the index of the existing top cell.
  if (!IsTopestGraph(cell_id)) {
    top_cell_info->top_cell_index = cell_graph_list_.size();
    top_cell_index_ = top_cell_info->top_cell_index;
  } else {
    auto top_cell = GetTopCell(cell_id, true);
    MS_EXCEPTION_IF_NULL(top_cell);
    top_cell_info->top_cell_index = top_cell->top_cell_index;
    top_cell_index_ = top_cell_info->top_cell_index;
  }
  top_cell_list_.emplace_back(top_cell_info);
  MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
}
  1914. std::string PynativeExecutor::GetCellOpInfo() {
  1915. if (cell_op_info_stack_.empty()) {
  1916. MS_LOG(EXCEPTION) << "The cell op info stack is empty";
  1917. }
  1918. return cell_op_info_stack_.top();
  1919. }
  1920. void PynativeExecutor::ReplaceCellOpInfoByCellId(const std::string &cell_id) {
  1921. if (cell_id.empty()) {
  1922. MS_LOG(EXCEPTION) << "The cell id is empty";
  1923. }
  1924. if (cell_op_info_stack_.empty()) {
  1925. MS_LOG(DEBUG) << "The cell op info stack is empty, No need replace";
  1926. return;
  1927. }
  1928. cell_op_info_stack_.top() = cell_op_info_stack_.top() + cell_id;
  1929. }
  1930. void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
  1931. bool is_param) {
  1932. if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
  1933. return;
  1934. }
  1935. auto tuple = args.cast<py::tuple>();
  1936. auto tuple_size = static_cast<int64_t>(tuple.size());
  1937. for (int64_t i = 0; i < tuple_size; ++i) {
  1938. auto id = GetId(tuple[i]);
  1939. if (is_param && node->isa<Parameter>()) {
  1940. auto param = node->cast<ParameterPtr>();
  1941. MS_EXCEPTION_IF_NULL(param);
  1942. SetParamNodeMapInGraphInfoMap(g, id, param);
  1943. }
  1944. SetNodeMapInGraphInfoMap(g, id, node, i);
  1945. SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param);
  1946. }
  1947. }
  1948. void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args,
  1949. const AnfNodePtr &node,
  1950. const std::vector<int64_t> &index_sequence, bool is_param) {
  1951. if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
  1952. return;
  1953. }
  1954. auto tuple = args.cast<py::tuple>();
  1955. auto tuple_size = static_cast<int64_t>(tuple.size());
  1956. for (int64_t i = 0; i < tuple_size; ++i) {
  1957. std::vector<int64_t> tmp = index_sequence;
  1958. tmp.emplace_back(i);
  1959. auto id = GetId(tuple[i]);
  1960. if (is_param && node->isa<Parameter>()) {
  1961. auto param = node->cast<ParameterPtr>();
  1962. MS_EXCEPTION_IF_NULL(param);
  1963. SetParamNodeMapInGraphInfoMap(g, id, param);
  1964. }
  1965. SetNodeMapInGraphInfoMap(g, id, node, tmp);
  1966. SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param);
  1967. }
  1968. }
// Finish recording the forward graph of `cell`: map the python output `out`
// to a graph node (building a MakeTuple or ValueNode if no recorded op
// produced it), then delegate graph finalization to EndGraphByOutId.
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
  // Compiled, non-dynamic cell outside any recording scope: nothing to build.
  if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) {
    PopCurrentCellOpInfoFromStack();
    MS_LOG(INFO) << "Endgraph already compiled";
    return;
  }
  auto out_id = GetId(out);
  // x =op1, y =op2, return (x, y)
  auto graph_info = graph_info_map_.at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
    if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
      // Output is a python sequence whose elements were produced separately:
      // pack them into a MakeTuple cnode and register it under out_id.
      auto tuple = out.cast<py::tuple>();
      auto tuple_size = static_cast<int64_t>(tuple.size());
      std::vector<AnfNodePtr> inputs;
      inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
      for (int64_t i = 0; i < tuple_size; i++) {
        inputs.emplace_back(GetInput(tuple[i], false));
      }
      auto cnode = curr_g_->NewCNode(inputs);
      SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
      SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
    } else {
      MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
      MakeValueNode(out, out_id);
    }
  }
  EndGraphByOutId(cell, cell_id, out, out_id, args);
}
// Finalize the recorded graph for `cell`: set its output node, build the ad
// grad graph when needed, and either splice the finished graph into its
// caller (nested cell) or resolve it as the top-level grad graph.
void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
                                       const std::string &out_id, const py::args &args) {
  AnfNodePtr output_node = GetObjNode(out, out_id);
  curr_g_->set_output(output_node);
  MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
  // Custom-bprop cells are finalized separately and need no ad grad here.
  if (EndBpropGraph(cell_id)) {
    MS_LOG(DEBUG) << "Get bprop function cell";
    return;
  }
  auto resource = GetResource(top_cell_id_);
  MS_EXCEPTION_IF_NULL(resource);
  resource->manager()->AddFuncGraph(curr_g_);
  UpdateCellGraph(cell, curr_g_, cell_id, true, false);
  FuncGraphPtr newfg = nullptr;
  auto top_cell = GetTopCell(top_cell_id_);
  MS_EXCEPTION_IF_NULL(top_cell);
  // Cell no Change
  if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) {
    MS_LOG(DEBUG) << "Cell is not dynamic, No need make ad grad";
    top_cell->need_grad = false;
    // Drop cached forward values/tensors held by the cnodes.
    std::unordered_set<AnfNodePtr> node_set;
    ClearCnodeRes(curr_g_->output(), &node_set);
    node_set.clear();
  } else {
    MS_LOG(DEBUG) << "Need make ad grad";
    if (!top_cell->need_grad) {
      std::unordered_set<AnfNodePtr> node_set;
      ClearCnodeRes(curr_g_->output(), &node_set);
      node_set.clear();
    }
    newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args);
  }
  if (graph_stack_.size() > 1) {
    // Nested cell: call the finished graph from the enclosing graph.
    std::vector<AnfNodePtr> inputs;
    inputs.emplace_back(NewValueNode(curr_g_));
    PopGraphStack();
    PopCurrentCellOpInfoFromStack();
    ReplaceCellOpInfoByCellId(cell_id);
    // connect the previous graph to the inside graph
    auto graph_prev = graph_stack_.top();
    for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], false);
      inputs.emplace_back(input);
    }
    auto out_cnode = graph_prev->NewCNode(inputs);
    SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
    SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
    SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
  } else {
    // Top-level graph: resolve symbols in the grad graph and store it.
    if (newfg != nullptr) {
      DumpGraphIR("before_resolve.ir", newfg);
      parse::ResolveFuncGraph(newfg, resource);
      DumpGraphIR("after_resolve.ir", newfg);
      resource->set_func_graph(newfg);
    }
    PopGraphStack();
    PopCurrentCellOpInfoFromStack();
  }
}
  2059. bool PynativeExecutor::EndBpropGraph(const string &cell_id) {
  2060. auto is_bprop_graph = IsBpropGraph(cell_id);
  2061. if (is_bprop_graph) {
  2062. if (!IsNestedGrad()) {
  2063. PopGraphStack();
  2064. PopCurrentCellOpInfoFromStack();
  2065. ReplaceCellOpInfoByCellId(cell_id);
  2066. }
  2067. return true;
  2068. }
  2069. return false;
  2070. }
// Compare the ops executed by `cell_id` in this step against the op info
// recorded in a previous step. Returns true when the grad graph must be
// rebuilt (cell is really dynamic, op info missing, first grad step, or the
// op sequence differs from the recorded one).
bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) {
  bool res = false;
  if (CheckRealDynamicCell(cell_id)) {
    MS_LOG(DEBUG) << "Cur cell " << cell_id << " is dynamic, no need check";
    return true;
  }
  if (GetCellOpInfo().empty()) {
    MS_LOG(DEBUG) << "Cell op info is empty";
    return true;
  }
  auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
                         [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
  if (it == cell_graph_list_.end() || IsFirstGradStep(top_cell_id_)) {
    return true;
  }
  // Compare against the record for the same call position; call_times cycles
  // through cell_ops_info when the cell is invoked multiple times per step.
  MS_LOG(DEBUG) << "Cell op info " << GetCellOpInfo() << ", old " << (*it)->cell_ops_info.at((*it)->call_times);
  if ((*it)->cell_ops_info.at((*it)->call_times) != GetCellOpInfo()) {
    res = true;
    UpdateCellDynamic(cell_id);
    MS_LOG(DEBUG) << "Cell self changed";
  }
  // Advance to the next recorded invocation, wrapping back to 0 at the end.
  (*it)->call_times = (*it)->call_times < (*it)->cell_ops_info.size() - 1 ? (*it)->call_times + 1 : 0;
  return res;
}
  2095. void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) {
  2096. for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) {
  2097. if ((*it)->cell_id != cell_id) {
  2098. (*it)->is_real_dynamic = true;
  2099. continue;
  2100. }
  2101. (*it)->is_real_dynamic = true;
  2102. break;
  2103. }
  2104. }
// Update (or create) the CellInfo entry for a cell that defines a custom
// bprop. Returns true when `cell` has a custom bprop (handled here), false
// when the caller should fall through to the normal UpdateCellGraph path.
bool PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                                            bool need_cloned, bool is_grad) {
  // Op info is only appended when called from EndGraph (clone requested, not grad).
  auto update_in_endgraph = need_cloned && !is_grad;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    // Bprop just save backward graph
    auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
                           [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
    if (it != cell_graph_list_.end()) {
      (*it)->is_grad = is_grad;
      // Re-point the graph-info entry when the stored graph is replaced.
      if (g != (*it)->fg) {
        graph_info_map_.update((*it)->fg, g);
        (*it)->fg = g;
      }
      if (update_in_endgraph && IsFirstGradStep(top_cell_id_)) {
        (*it)->cell_ops_info.emplace_back(GetCellOpInfo());
      }
      MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id;
    } else {
      // First time this bprop cell is seen: record it together with the id of
      // its python bprop function.
      py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
      auto bprop_func_cell_id = GetId(bprop_func);
      MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id
                    << " cell ops info " << GetCellOpInfo();
      auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id);
      cell_info->cell_ops_info.emplace_back(GetCellOpInfo());
      // Inside a bprop, append; otherwise insert at the current top cell's slot.
      if (in_bprop_process_) {
        cell_graph_list_.emplace_back(cell_info);
      } else {
        cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info);
      }
    }
    return true;
  }
  return false;
}
// Create or refresh the CellInfo entry for `cell_id` with graph `g`.
// `need_cloned` requests storing a clone of `g` (so later mutation of the
// live graph does not corrupt the record); `is_grad` marks whether this call
// happens during grad building. Custom-bprop cells are delegated to
// UpdateBpropCellGraph.
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                                       bool need_cloned, bool is_grad) {
  auto update_in_endgraph = need_cloned && !is_grad;
  if (UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad)) {
    return;
  }
  FuncGraphPtr tmp = g;
  // A (non-really) dynamic cell past its first grad step keeps the live graph.
  if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) {
    MS_LOG(DEBUG) << "No need cloned";
    need_cloned = false;
  }
  // Clone `g` into `tmp`, migrate its graph-info entry, and strip cached
  // forward values from the clone's cnodes.
  auto clone_fn = [&g, &tmp, need_cloned, this]() {
    if (!need_cloned) {
      return;
    }
    tmp = BasicClone(g);
    graph_info_map_.update(g, tmp);
    std::unordered_set<AnfNodePtr> node_set;
    ClearCnodeRes(tmp->output(), &node_set);
    node_set.clear();
  };
  // First call or cell id not exist
  if (update_in_endgraph && (IsFirstGradStep(top_cell_id_) || !CheckCellGraph(cell_id))) {
    if (!CheckCellGraph(cell_id)) {
      clone_fn();
      MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo();
      auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, tmp, cell_id, "");
      cell_info->cell_ops_info.emplace_back(GetCellOpInfo());
      if (in_bprop_process_) {
        cell_graph_list_.emplace_back(cell_info);
      } else {
        cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info);
      }
    } else {
      // Same cell invoked again in the first grad step: record another
      // per-invocation op-info string.
      auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
                             [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
      if (it != cell_graph_list_.end()) {
        (*it)->cell_ops_info.emplace_back(GetCellOpInfo());
      }
      MS_LOG(DEBUG) << "Add another same cell ops info";
    }
    return;
  }
  // Existing entry: swap in the (possibly cloned) new graph.
  for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) {
    if ((*it)->cell_id != cell_id) {
      continue;
    }
    if (IsFirstGradStep(cell_id)) {
      // no compute grad
      (*it)->is_grad = is_grad;
    }
    if (need_cloned) {
      clone_fn();
      if ((*it)->fg != nullptr) {
        graph_info_map_.erase((*it)->fg);
      }
      MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with cloned new " << tmp.get();
      (*it)->fg = tmp;
    }
    if (!need_cloned && !is_grad) {
      graph_info_map_.erase((*it)->fg);
      MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with new " << tmp.get();
      (*it)->fg = tmp;
    }
    break;
  }
}
  2206. void PynativeExecutor::ClearCnodeRes(const AnfNodePtr &node, std::unordered_set<AnfNodePtr> *node_set) {
  2207. MS_EXCEPTION_IF_NULL(node);
  2208. MS_EXCEPTION_IF_NULL(node_set);
  2209. if (!node->isa<CNode>() || (*node_set).find(node) != (*node_set).end()) {
  2210. return;
  2211. }
  2212. (*node_set).insert(node);
  2213. auto cnode = node->cast<CNodePtr>();
  2214. cnode->clear_inputs_value();
  2215. cnode->set_forward(nullptr, "");
  2216. for (size_t i = 0; i < cnode->size(); ++i) {
  2217. auto n = cnode->input(i);
  2218. ClearCnodeRes(n, node_set);
  2219. }
  2220. }
// Build the backward (grad) graph for forward graph `g` using resource `r`.
// Handles cells with a user-defined bprop by attaching the bprop-cut graph
// and substituting the call arguments as value nodes. Returns the new grad
// FuncGraph.
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
                                             const std::string &cell_id, const py::args &args) {
  bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
  if (is_custom_bprop) {
    // A custom bprop cell must not own trainable parameters.
    size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
    if (par_number > 0) {
      MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
                        << " parameters that is not supported in the net.";
    }
    MS_LOG(INFO) << "Use cell custom bprop function.";
    // Link forward graph and bprop graph through the transform maps so ad
    // can find the user-provided backward.
    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
    if (bprop_graph != nullptr) {
      (void)g->transforms().emplace(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
    }
  }
  auto is_top = IsTopGraph(cell_id);
  MS_LOG(DEBUG) << "Grad top cell " << is_top;
  // Before make grad graph, we need to run auto-monad on forward graph,
  // so that side effects in forward graph can be handled in grad graph.
  (void)pipeline::AutoMonad(g);
  set_need_replace_forward(!IsNestedGrad());
  // Obtain grad graph
  auto newfg = ad::Grad(g, r, is_top);
  if (is_custom_bprop) {
    // Freeze the forward inputs of the bprop cell into value nodes.
    auto params = newfg->parameters();
    auto manager = Manage({newfg}, false);
    if (args.size() > params.size()) {
      MS_EXCEPTION(TypeError) << "The number of arguments " << args.size()
                              << " is more than the number of parameters required, which is " << params.size();
    }
    for (size_t i = 0; i < args.size(); i++) {
      ValuePtr value = PyAttrValue(args[i]);
      auto v_node = NewValueNode(value);
      manager->Replace(params[i], v_node);
    }
    UpdateCellGraph(cell, newfg, cell_id, false, false);
  }
  return newfg;
}
  2261. std::string PynativeExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args,
  2262. py::object *forward_args, py::object *sens) {
  2263. auto size = args.size();
  2264. size_t forward_args_size = size;
  2265. if (has_sens) {
  2266. if (size >= 1) {
  2267. --forward_args_size;
  2268. if (sens != nullptr) {
  2269. *sens = args[forward_args_size];
  2270. }
  2271. }
  2272. py::tuple f_args(forward_args_size);
  2273. for (size_t i = 0; i < forward_args_size; ++i) {
  2274. f_args[i] = args[i];
  2275. }
  2276. *forward_args = f_args;
  2277. }
  2278. const auto &cell_id = GetCellId(cell, *forward_args);
  2279. return cell_id;
  2280. }
  2281. void PynativeExecutor::SaveAllValueNodeTensors(const FuncGraphPtr &graph) {
  2282. std::unordered_set<tensor::TensorPtr> all_value_node_tensors;
  2283. auto trace_function = [&all_value_node_tensors](const AnfNodePtr &anf_node) {
  2284. auto value = GetValueNode(anf_node);
  2285. if (value) {
  2286. if (value->isa<tensor::Tensor>()) {
  2287. auto tensor = value->cast<tensor::TensorPtr>();
  2288. MS_EXCEPTION_IF_NULL(tensor);
  2289. if (tensor->device_address()) {
  2290. all_value_node_tensors.emplace(tensor);
  2291. }
  2292. } else if (value->isa<ValueTuple>()) {
  2293. auto tuple = value->cast<ValueTuplePtr>();
  2294. MS_EXCEPTION_IF_NULL(tuple);
  2295. for (size_t i = 0; i < tuple->size(); i++) {
  2296. if ((*tuple)[i]->isa<tensor::Tensor>()) {
  2297. auto tensor = (*tuple)[i]->cast<tensor::TensorPtr>();
  2298. MS_EXCEPTION_IF_NULL(tensor);
  2299. if (tensor->device_address()) {
  2300. all_value_node_tensors.emplace(tensor);
  2301. }
  2302. }
  2303. }
  2304. }
  2305. }
  2306. return FOLLOW;
  2307. };
  2308. (void)TopoSort(graph->get_return(), SuccDeeperSimple, trace_function);
  2309. all_value_node_tensors_ = all_value_node_tensors;
  2310. }
// Build, optimize and compile the grad graph for `cell`.
// Skips recompilation when neither weights/sens nor the cell changed since
// the last step; otherwise assembles df_builder (inputs + weights), runs
// ad::Grad via GradGraph, then the optimize / task-emit / execute pipeline.
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                    const py::args &args) {
  auto size = args.size();
  py::object sens = py::none();
  py::object forward_args = args;
  // Strips a trailing sens argument when grad->sens_param() is set.
  const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args, &forward_args, &sens);
  MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id;
  const auto &params_changed = CheckGradParamsChanged(cell_id, weights, sens);
  // Nothing changed since last compilation: reuse the existing grad graph.
  if (!params_changed && !IsFirstGradStep(cell_id) && !CheckRealDynamicCell(cell_id)) {
    UpdateTopCellInfo(cell_id, false);
    ClearDynamicTopRes(cell_id);
    MS_LOG(INFO) << "Gradgraph already compiled";
    return;
  }
  // Nested graph
  if (CheckCellGraph(cell_id) && !graph_stack_.empty()) {
    MS_LOG(DEBUG) << "Set nested top graph";
    SetNestedTopGraph(cell, forward_args, cell_id);
  }
  auto df_builder = GetDfbuilder(cell_id);
  MS_EXCEPTION_IF_NULL(df_builder);
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
  // Set all params(input+weights)
  SetGradGraphParams(df_builder, resource, size);
  // Get params(weights) require derivative
  auto w_args = GetWeightsArgs(weights, df_builder);
  // Get the parameters items and add the value to args_spec
  auto args_spec = GetArgsSpec(args, df_builder);
  resource->set_args_spec(args_spec);
  // Get real grad graph
  DumpGraphIR("before_grad.ir", resource->func_graph());
  GradGraph(resource->func_graph(), grad, w_args, size, cell_id);
  DumpGraphIR("after_grad.ir", df_builder);
  resource->set_func_graph(df_builder);
  resource->manager()->KeepRoots({df_builder});
  resource->results()[pipeline::kBackend] = compile::CreateBackend();
  MS_LOG(INFO) << "Start opt";
  // Dynamic cells may re-enter the forward, so keep their value-node tensors alive.
  if (has_dynamic_cell_) {
    SaveAllValueNodeTensors(resource->func_graph());
  }
  PynativeOptimizeAction(resource);
  DumpGraphIR("after_opt.ir", resource->func_graph());
  SaveTensorsInValueNode(resource);
  TaskEmitAction(resource);
  ExecuteAction(resource);
  ClearUselessRes(df_builder, cell, cell_id);
  UpdateCellGraph(cell, curr_g_, cell_id, false, true);
  UpdateTopCellInfo(cell_id, true);
  resource->Clean();
}
  2363. void PynativeExecutor::ClearDynamicTopRes(const std::string &cell_id) {
  2364. if (IsTopestGraph(cell_id)) {
  2365. op_index_map_.clear();
  2366. }
  2367. // Delete unused top cell resource
  2368. if (!CheckDynamicCell(cell_id)) {
  2369. return;
  2370. }
  2371. int same_top_cell_count = 0;
  2372. for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
  2373. if ((*it)->cell_id == cell_id) {
  2374. ++same_top_cell_count;
  2375. if (same_top_cell_count > 1) {
  2376. graph_info_map_.erase((*it)->df_builder);
  2377. it = top_cell_list_.erase(it);
  2378. --same_top_cell_count;
  2379. } else {
  2380. ++it;
  2381. }
  2382. } else {
  2383. ++it;
  2384. }
  2385. }
  2386. }
  2387. bool PynativeExecutor::CheckGradParamsChanged(const std::string &cell_id, const py::object &weights,
  2388. const py::object &sens) {
  2389. bool res = false;
  2390. auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
  2391. [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
  2392. if (it == top_cell_list_.end()) {
  2393. return res;
  2394. }
  2395. auto fn = [](const py::object &arg) {
  2396. std::string arg_id;
  2397. if (py::isinstance<tensor::Tensor>(arg)) {
  2398. auto tensor_ptr = py::cast<tensor::TensorPtr>(arg);
  2399. auto dtype = tensor_ptr->data_type();
  2400. auto shape = tensor_ptr->shape();
  2401. std::stringstream ss;
  2402. std::for_each(shape.begin(), shape.end(), [&ss](int i) { ss << i; });
  2403. arg_id = ss.str() + std::to_string(dtype);
  2404. } else {
  2405. arg_id = std::string(py::str(arg));
  2406. }
  2407. return arg_id;
  2408. };
  2409. std::string sens_id = "sens";
  2410. if (!py::isinstance<py::none>(sens)) {
  2411. sens_id = fn(sens);
  2412. }
  2413. if (!(*it)->sens_id.empty() && (*it)->sens_id != sens_id) {
  2414. (*it)->sens_id = sens_id;
  2415. }
  2416. std::string weights_id = fn(weights);
  2417. if (!(*it)->weights_id.empty() && (*it)->weights_id != weights_id) {
  2418. (*it)->weights_id = weights_id;
  2419. res = true;
  2420. }
  2421. return res;
  2422. }
// Prepare a top-cell entry for a cell being grad-ed while another graph is
// still being recorded (nested grad). Reuses or creates the pipeline
// resource, registers a fresh df_builder, then builds the nested grad graph
// from the cell's stored forward graph.
void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id) {
  if (IsTopGraph(cell_id)) {
    VectorClear<std::vector<TopCellInfoPtr>>(&top_cell_list_, cell_id);
  }
  ResourcePtr resource = nullptr;
  auto ia = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
  if (ia != top_cell_list_.end()) {
    resource = GetResource((*ia)->cell_id);
    MS_EXCEPTION_IF_NULL(resource);
    MS_LOG(DEBUG) << "Find old resource " << resource.get();
  }
  if (resource == nullptr) {
    resource = std::make_shared<pipeline::Resource>();
    resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
    MS_LOG(DEBUG) << "Make new resource " << resource.get();
  }
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
  auto graph_info = std::make_shared<GraphInfo>(cell_id);
  graph_info_map_[df_builder] = graph_info;
  auto top_cell_info = std::make_shared<TopCellInfo>(false, resource, df_builder, cell_id);
  top_cell_info->top_cell_index = top_cell_index_;
  top_cell_list_.emplace_back(top_cell_info);
  // Look up the stored forward graph of this cell under the current top cell.
  FuncGraphPtr forward_graph = nullptr;
  auto ib = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
                         [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
  if (ib != cell_graph_list_.end()) {
    forward_graph = (*ib)->fg;
  }
  MS_EXCEPTION_IF_NULL(forward_graph);
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    DumpGraphIR("nested_bprop.ir", forward_graph);
    // Custom bprop get backward graph(before opt), which use like other forward graph
    curr_g_ = forward_graph;
    resource->set_func_graph(forward_graph);
    return;
  }
  // Copy weights parameters
  ReplaceGraphParams(df_builder, forward_graph, cell_id);
  resource->manager()->AddFuncGraph(forward_graph);
  DumpGraphIR("nested_fg.ir", forward_graph);
  set_need_replace_forward(false);
  auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args);
  resource->set_func_graph(newfg);
}
// For a nested grad, migrate the weight parameters (those with defaults) of
// `forward_graph` — and of every later cell graph under the same top cell —
// into `df_builder`, replacing the original parameter nodes so the nested
// grad graph owns its own copies. The old→new pairs are remembered in
// replace_weights_map_ for later restoration.
void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
                                          const std::string &cell_id) {
  // Collect the graph of `cell_id` and every cell graph recorded after it.
  std::vector<FuncGraphPtr> graph_before{};
  bool index_find = false;
  for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) {
    if (IsBpropGraph((*it)->cell_id) || (*it)->fg == nullptr) {
      continue;
    }
    if (index_find) {
      graph_before.emplace_back((*it)->fg);
      continue;
    }
    if ((*it)->cell_id == cell_id) {
      index_find = true;
      graph_before.emplace_back((*it)->fg);
    }
  }
  auto manager = Manage({forward_graph}, false);
  for (const auto &f : graph_before) {
    auto graph_info = graph_info_map_.at(f);
    MS_EXCEPTION_IF_NULL(graph_info);
    for (const auto &it : graph_info->params) {
      // Only weight parameters (those carrying a default value) are copied.
      if (!it.second->has_default()) {
        continue;
      }
      auto new_param = df_builder->add_parameter();
      new_param->set_abstract(it.second->abstract());
      new_param->set_name(it.second->name());
      new_param->set_default_param(it.second->default_param());
      // NOTE(review): both branches of this ternary evaluate to
      // it.second->scope(); it is a no-op kept for fidelity.
      ScopePtr scope = (it.second->scope() != kDefaultScope) ? it.second->scope() : kDefaultScope;
      new_param->set_scope(scope);
      manager->Replace(it.second, new_param);
      replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param));
      MS_LOG(DEBUG) << "Param name " << new_param->name() << " ptr " << new_param.get();
      auto graph_info_of_df_builder = graph_info_map_.at(df_builder);
      MS_EXCEPTION_IF_NULL(graph_info_of_df_builder);
      graph_info_of_df_builder->params[it.first] = new_param;
      SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param);
      SetNodeMapInGraphInfoMap(df_builder, it.first, new_param);
    }
  }
  graph_before.clear();
}
  2512. void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) {
  2513. std::vector<AnfNodePtr> new_params;
  2514. for (size_t i = 0; i < size; i++) {
  2515. ParameterPtr p = std::make_shared<Parameter>(df_builder);
  2516. new_params.emplace_back(p);
  2517. }
  2518. MS_LOG(DEBUG) << "GradNet weight param size " << df_builder->parameters().size();
  2519. // df_builder_->parameters() set in GetInput, which are weights params
  2520. new_params.insert(new_params.end(), df_builder->parameters().begin(), df_builder->parameters().end());
  2521. df_builder->set_parameters(new_params);
  2522. resource->manager()->SetParameters(df_builder, new_params);
  2523. }
  2524. std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) {
  2525. std::vector<AnfNodePtr> w_args;
  2526. if (!py::hasattr(weights, "__parameter_tuple__")) {
  2527. MS_LOG(DEBUG) << "No paramter_tuple get";
  2528. return {};
  2529. }
  2530. auto tuple = weights.cast<py::tuple>();
  2531. MS_LOG(DEBUG) << "Get weights tuple size " << tuple.size();
  2532. w_args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  2533. for (size_t it = 0; it < tuple.size(); ++it) {
  2534. auto param = tuple[it];
  2535. auto param_id = GetId(param);
  2536. AnfNodePtr para_node = nullptr;
  2537. auto graph_info = graph_info_map_.at(df_builder);
  2538. MS_EXCEPTION_IF_NULL(graph_info);
  2539. if (graph_info->params.find(param_id) != graph_info->params.end() &&
  2540. graph_info->node_map.find(param_id) != graph_info->node_map.end()) {
  2541. para_node = graph_info->node_map[param_id].first;
  2542. } else {
  2543. auto name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
  2544. if (py::isinstance<py::none>(name_attr)) {
  2545. MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
  2546. }
  2547. auto param_name = py::cast<std::string>(name_attr);
  2548. auto free_param = df_builder->add_parameter();
  2549. free_param->set_name(param_name);
  2550. auto value = py::cast<tensor::TensorPtr>(param);
  2551. free_param->set_default_param(value);
  2552. free_param->debug_info()->set_name(param_name);
  2553. para_node = free_param;
  2554. }
  2555. w_args.emplace_back(para_node);
  2556. }
  2557. return w_args;
  2558. }
  2559. abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder) {
  2560. abstract::AbstractBasePtrList args_spec;
  2561. std::size_t size = args.size();
  2562. auto df_params = df_builder->parameters();
  2563. if (df_params.size() < size) {
  2564. MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << size;
  2565. }
  2566. // input params
  2567. for (std::size_t i = 0; i < size; i++) {
  2568. ValuePtr converted = nullptr;
  2569. bool succ = parse::ConvertData(args[i], &converted);
  2570. if (!succ) {
  2571. MS_LOG(EXCEPTION) << "Args convert error";
  2572. }
  2573. bool broaden = true;
  2574. auto abs = abstract::FromValue(converted, broaden);
  2575. args_spec.emplace_back(abs);
  2576. auto param_node = std::static_pointer_cast<Parameter>(df_params[i]);
  2577. param_node->set_abstract(abs);
  2578. }
  2579. // weights params
  2580. for (const auto &param : df_params) {
  2581. auto param_node = std::static_pointer_cast<Parameter>(param);
  2582. if (param_node->has_default()) {
  2583. ValuePtr value = param_node->default_param();
  2584. auto ptr = value->ToAbstract();
  2585. MS_EXCEPTION_IF_NULL(ptr);
  2586. args_spec.emplace_back(ptr);
  2587. param_node->set_abstract(ptr);
  2588. }
  2589. }
  2590. MS_LOG(DEBUG) << "Args_spec size " << args_spec.size();
  2591. return args_spec;
  2592. }
  2593. void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op,
  2594. const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
  2595. FuncGraphPtr top_g = nullptr;
  2596. auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
  2597. [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
  2598. if (it != cell_graph_list_.end()) {
  2599. top_g = (*it)->fg;
  2600. }
  2601. MS_EXCEPTION_IF_NULL(top_g);
  2602. auto nparam = top_g->parameters().size();
  2603. MS_LOG(DEBUG) << "Top graph input params size " << nparam;
  2604. std::ostringstream ss;
  2605. ss << "grad{" << nparam << "}";
  2606. auto df_builder = GetDfbuilder(cell_id);
  2607. MS_EXCEPTION_IF_NULL(df_builder);
  2608. auto resource = GetResource(cell_id);
  2609. MS_EXCEPTION_IF_NULL(resource);
  2610. df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  2611. df_builder->debug_info()->set_name(ss.str());
  2612. auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g->parameters(), weights);
  2613. std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
  2614. auto df_params = df_builder->parameters();
  2615. if (df_params.size() < arg_size) {
  2616. MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << arg_size;
  2617. }
  2618. for (size_t i = 0; i < arg_size; ++i) {
  2619. inputs.emplace_back(df_params[i]);
  2620. }
  2621. auto out = df_builder->NewCNode(inputs);
  2622. df_builder->set_output(out);
  2623. resource->manager()->AddFuncGraph(df);
  2624. resource->manager()->AddFuncGraph(df_builder);
  2625. }
// Drop per-cell resources that are no longer needed after grad has finished
// for a topmost cell. Skipped for custom-bprop cells, for the first grad step
// of a dynamic top cell, and for non-topmost cells.
void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell,
                                       const std::string &cell_id) {
  graph_info_map_.erase(df_builder);
  bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
  bool is_dynamic_top_fist_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(cell_id);
  bool is_topmost = IsTopestGraph(cell_id);
  if (has_custom_bprop || is_dynamic_top_fist_grad || !is_topmost) {
    return;
  }
  MS_LOG(DEBUG) << "Update topmost cell graph list and graph info map";
  // Clear graph_info_map_
  std::vector<std::string> l{};
  bool index_find = false;
  // Phase 1: find the end of this top cell's slice of cell_graph_list_, i.e.
  // the start index of the NEXT top cell after cell_id in top_cell_list_.
  auto it_end = cell_graph_list_.end();
  for (size_t i = 0; i < top_cell_list_.size(); ++i) {
    if (top_cell_list_[i]->cell_id == cell_id) {
      index_find = true;
      continue;
    }
    if (index_find) {
      it_end = cell_graph_list_.begin() + top_cell_list_[i]->top_cell_index;
      break;
    }
  }
  // Phase 2: walk this slice, releasing each forward graph's cnode resources;
  // `index_find` is reused here to mark "at or after cell_id", and those cell
  // ids are collected in `l` for graph-info cleanup below.
  index_find = false;
  for (auto it = cell_graph_list_.begin() + top_cell_index_; it != it_end; ++it) {
    if ((*it)->fg != nullptr) {
      std::unordered_set<AnfNodePtr> node_set;
      ClearCnodeRes((*it)->fg->output(), &node_set);
      node_set.clear();
      (*it)->fg = nullptr;
    }
    if (index_find) {
      l.emplace_back((*it)->cell_id);
      continue;
    }
    if ((*it)->cell_id == cell_id) {
      index_find = true;
      l.emplace_back((*it)->cell_id);
    }
  }
  // Erase every graph_info_map_ entry whose cell id contains one of the
  // collected ids (substring match, presumably because ids embed the cell id
  // prefix — confirm against GetCellId's id format).
  for (const auto &it : l) {
    for (auto ic = graph_info_map_.begin(); ic != graph_info_map_.end();) {
      if (ic->second->cell_id.find(it) != std::string::npos) {
        ic = graph_info_map_.erase(ic);
      } else {
        ++ic;
      }
    }
  }
}
// Check whether the cached graph for `cell` must be rebuilt. Cells are matched
// by the first PTR_LEN characters of the cell id (presumably the object
// pointer part of the id — confirm against GetCellId). On a match the stale
// entry is removed so the graph gets re-captured; returns true in that case.
py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  BaseRef ret = false;
  AddNestedGradOrder();
  if (!grad_running()) {
    MS_LOG(DEBUG) << "Grad not running yet";
    return BaseRefToPyData(ret);
  }
  const auto &cell_id = GetCellId(cell, args);
  std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
  MS_LOG(DEBUG) << "Key is " << key;
  for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) {
    MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id;
    if (key != (*it)->cell_id.substr(0, std::min(PTR_LEN, (*it)->cell_id.size()))) {
      continue;
    }
    MS_LOG(DEBUG) << "Delete cellid from cell graph list";
    // erase() invalidates `it`; safe because the loop exits immediately.
    graph_info_map_.erase((*it)->fg);
    cell_graph_list_.erase(it);
    ret = true;
    break;
  }
  return BaseRefToPyData(ret);
}
  2700. py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
  2701. const auto &cell_id = GetCellId(cell, args);
  2702. auto top_cell = GetTopCell(cell_id);
  2703. bool forward_run = false;
  2704. if (top_cell != nullptr) {
  2705. forward_run = top_cell->forward_already_run;
  2706. if (forward_run) {
  2707. top_cell_index_ = top_cell->top_cell_index;
  2708. }
  2709. }
  2710. MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ "
  2711. << top_cell_id_;
  2712. return BaseRefToPyData(forward_run);
  2713. }
// Execute the compiled grad graph for `cell` with `args` and store the python
// result in *ret. Handles the optional trailing sens (gradient seed) argument
// and, for nested grads, hooks the result back into the outer graph.
void PynativeExecutor::RunInner(const py::object &cell, const py::tuple &args, const py::object &phase,
                                py::object *ret) {
  MS_EXCEPTION_IF_NULL(ret);
  auto cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "Run start cell id " << cell_id;
  // If this id strictly extends a recorded top cell id, the extra suffix means
  // an extra sens argument was appended to the call.
  bool has_sens = false;
  for (const auto &it : top_cell_list_) {
    if (cell_id.find(it->cell_id) != std::string::npos && cell_id != it->cell_id) {
      has_sens = true;
      break;
    }
  }
  py::object forward_args = args;
  cell_id = GetGradCellId(has_sens, cell, args, &forward_args);
  MS_LOG(DEBUG) << "Run has sens " << has_sens << " forward cell id " << cell_id;
  auto resource = GetResource(cell_id);
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Run resource ptr " << resource.get();
  VectorRef arg_list;
  py::tuple converted_args = ConvertArgs(args);
  pipeline::ProcessVmArgInner(converted_args, resource, &arg_list);
  // The compiled output must exist and be a VM eval function.
  if (resource->results().find(pipeline::kOutput) == resource->results().end()) {
    MS_LOG(EXCEPTION) << "Can't find run graph output";
  }
  if (!resource->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Run graph is not VmEvalFuncPtr";
  }
  compile::VmEvalFuncPtr run = resource->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  MS_EXCEPTION_IF_NULL(run);
  std::string backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  // Flag "grad running" around the VM call so re-entrant paths can detect it.
  set_grad_runing(true);
  BaseRef value = (*run)(arg_list);
  set_grad_runing(false);
  MS_LOG(DEBUG) << "Eval run end " << value.ToString();
  *ret = BaseRefToPyData(value);
  // For a cell that was just VM-compiled, wire the grad result into the outer
  // graph when we are inside a nested grad (custom bprop takes priority).
  auto do_vm_compiled =
    std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
                [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->do_vm_compiled; });
  if (do_vm_compiled) {
    if (MakeBpropNestedCnode(cell, *ret, cell_id)) {
      return;
    }
    MakeNestedCnode(cell_id, args, resource, *ret, has_sens);
  }
}
  2760. bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) {
  2761. if (graph_stack_.empty() || !py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
  2762. MS_LOG(DEBUG) << "No nested bprop grad find";
  2763. return false;
  2764. }
  2765. auto out_id = GetId(out);
  2766. std::vector<AnfNodePtr> inputs;
  2767. inputs.emplace_back(NewValueNode(curr_g_));
  2768. PopGraphStack();
  2769. auto graph_info = graph_info_map_.at(curr_g_);
  2770. MS_EXCEPTION_IF_NULL(graph_info);
  2771. for (const auto &ig : graph_info->params) {
  2772. if (!ig.second->has_default()) {
  2773. inputs.emplace_back(ig.second);
  2774. }
  2775. }
  2776. auto cnode = curr_g_->NewCNode(inputs);
  2777. SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
  2778. SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
  2779. MS_LOG(DEBUG) << "Custom bprop make nested node is " << cnode->DebugString(4);
  2780. return true;
  2781. }
  2782. void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
  2783. const py::object &out, bool has_sens) {
  2784. if (graph_stack_.empty()) {
  2785. MS_LOG(DEBUG) << "No nested grad find";
  2786. return;
  2787. }
  2788. auto graph_prev = graph_stack_.top();
  2789. MS_EXCEPTION_IF_NULL(graph_prev);
  2790. MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get();
  2791. auto newfg = resource->func_graph();
  2792. MS_EXCEPTION_IF_NULL(newfg);
  2793. auto inputs_size = args.size();
  2794. if (has_sens) {
  2795. inputs_size -= 1;
  2796. }
  2797. std::vector<AnfNodePtr> inputs;
  2798. inputs.emplace_back(NewValueNode(newfg));
  2799. for (size_t i = 0; i < inputs_size; ++i) {
  2800. inputs.emplace_back(GetInput(args[i], false));
  2801. }
  2802. if (newfg->parameters().size() > args.size()) {
  2803. RecoverGraphParams(newfg, cell_id, &inputs);
  2804. }
  2805. auto out_id = GetId(out);
  2806. auto cnode = graph_prev->NewCNode(inputs);
  2807. SetTupleArgsToGraphInfoMap(graph_prev, out, cnode);
  2808. SetNodeMapInGraphInfoMap(graph_prev, out_id, cnode);
  2809. MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
  2810. }
  2811. void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id,
  2812. std::vector<AnfNodePtr> *inputs) {
  2813. FuncGraphPtr forward_graph = nullptr;
  2814. auto ic = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
  2815. [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
  2816. if (ic != cell_graph_list_.end()) {
  2817. forward_graph = (*ic)->fg;
  2818. }
  2819. MS_EXCEPTION_IF_NULL(forward_graph);
  2820. auto param_list = replace_weights_map_.at(forward_graph);
  2821. auto params = newfg->parameters();
  2822. auto manage = Manage({newfg}, false);
  2823. for (const auto &it : params) {
  2824. auto param = it->cast<ParameterPtr>();
  2825. if (!param->has_default()) {
  2826. continue;
  2827. }
  2828. for (auto p = param_list.begin(); p != param_list.end();) {
  2829. MS_LOG(DEBUG) << "Param name " << param->name() << " ptr " << param.get();
  2830. if (p->second->name() == param->name()) {
  2831. manage->Replace(param, p->first);
  2832. inputs->emplace_back(p->first);
  2833. param_list.erase(p);
  2834. break;
  2835. }
  2836. }
  2837. }
  2838. replace_weights_map_.erase(forward_graph);
  2839. }
  2840. void PynativeExecutor::Clear(const std::string &cell_id) {
  2841. if (cell_id.empty()) {
  2842. Clean();
  2843. return;
  2844. }
  2845. MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id;
  2846. for (auto it = graph_info_map_.begin(); it != graph_info_map_.end();) {
  2847. if (it->second->cell_id.find(cell_id) != std::string::npos) {
  2848. it = graph_info_map_.erase(it);
  2849. } else {
  2850. ++it;
  2851. }
  2852. }
  2853. // Maybe exit in runop step
  2854. auto ms_context = MsContext::GetInstance();
  2855. if (ms_context != nullptr) {
  2856. ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  2857. }
  2858. ConfigManager::GetInstance().ResetIterNum();
  2859. VectorClear<std::vector<CellInfoPtr>>(&cell_graph_list_, cell_id);
  2860. VectorClear<std::vector<TopCellInfoPtr>>(&top_cell_list_, cell_id);
  2861. node_abs_map_.clear();
  2862. }
// Global cleanup after a grad step: drop cached abstracts and forward ids,
// release autodiff resources and reclaim the optimizer.
void PynativeExecutor::Clean() {
  MS_LOG(DEBUG) << "Clean";
  SubNestedGradOrder();
  node_abs_map_.clear();
  obj_to_forward_id_.clear();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}
// Reset the executor to its initial state: run the regular Clean(), then zero
// every counter/flag and empty every cache, list, and stack the executor owns.
void PynativeExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear all res";
  Clean();
  graph_id_ = 0;
  grad_order_ = 0;
  grad_flag_ = false;
  in_grad_process_ = false;
  in_bprop_process_ = false;
  has_dynamic_cell_ = false;
  grad_is_running_ = false;
  need_replace_forward_ = true;
  curr_g_ = nullptr;
  top_cell_id_.clear();
  graph_info_map_.clear();
  replace_weights_map_.clear();
  cell_graph_list_.clear();
  top_cell_list_.clear();
  cell_input_args_.clear();
  op_index_map_.clear();
  cell_op_index_with_tensor_id_.clear();
  cell_tensor_id_with_tensor_.clear();
  prim_abs_list_.clear();
  all_value_node_tensors_.clear();
  // std::stack has no clear(); swapping with a fresh stack empties it.
  std::stack<FuncGraphPtr>().swap(graph_stack_);
  std::stack<std::string>().swap(cell_op_info_stack_);
}
// Python-facing entry: start capturing a new graph for `cell`; exceptions are
// funneled through PynativeExecutorTry.
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args);
}
// Python-facing entry: finish capturing the graph for `cell`. The memory
// cleaner is notified around the call so it can scope end-graph allocations.
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  MS_LOG(DEBUG) << "Enter end graph process.";
  auto &mem_cleaner = pipeline::Resource::mem_cleaner();
  mem_cleaner.EnterPynativeEndGraphProcess();
  PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
  mem_cleaner.LeavePynativeEndGraphProcess();
  MS_LOG(DEBUG) << "Leave end graph process.";
}
// Python-facing entry: build the grad graph for `cell` with the given grad
// operation and weights; exceptions are funneled through PynativeExecutorTry.
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
}
// Python-facing entry (__call__): run the compiled grad graph and return its
// python result.
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) {
  py::object ret;
  PynativeExecutorTry(this, &PynativeExecutor::RunInner, cell, args, phase, &ret);
  return ret;
}
// Block until the device stream of the current session has drained.
// Raises NotExistsError when no session exists yet.
void PynativeExecutor::Sync() {
  if (session == nullptr) {
    MS_EXCEPTION(NotExistsError) << "No session has been created!";
  }
  session->SyncStream();
}
  2923. void PynativeExecutor::EnterConstruct(const py::object &cell) {
  2924. if (top_cell_ != nullptr) {
  2925. return;
  2926. }
  2927. top_cell_ = cell.ptr();
  2928. pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
  2929. MS_LOG(DEBUG) << "Enter construct process.";
  2930. }
  2931. void PynativeExecutor::LeaveConstruct(const py::object &cell) {
  2932. if (top_cell_ != cell.ptr()) {
  2933. return;
  2934. }
  2935. top_cell_ = nullptr;
  2936. pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
  2937. MS_LOG(DEBUG) << "Leave construct process.";
  2938. }
// Expose PynativeExecutor to Python as the singleton class `PynativeExecutor_`
// used by the front end's pynative mode.
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.")
                           .def("check_run", &PynativeExecutor::CheckAlreadyRun, "pynative check graph run before.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
                           .def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
                           .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.")
                           .def("enter_construct", &PynativeExecutor::EnterConstruct,
                                "Do something before enter construct function.")
                           .def("leave_construct", &PynativeExecutor::LeaveConstruct,
                                "Do something after leave construct function.");
                       }));
  2957. } // namespace mindspore::pynative