You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 118 kB

4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/session_basic.h"
  17. #include <algorithm>
  18. #include <set>
  19. #include <queue>
  20. #include <utility>
  21. #include <functional>
  22. #include <unordered_map>
  23. #include "utils/hash_map.h"
  24. #include "ops/primitive_c.h"
  25. #include "ir/manager.h"
  26. #include "abstract/utils.h"
  27. #include "backend/kernel_compiler/common_utils.h"
  28. #include "base/core_ops.h"
  29. #include "base/base_ref_utils.h"
  30. #include "common/trans.h"
  31. #include "utils/config_manager.h"
  32. #include "backend/session/anf_runtime_algorithm.h"
  33. #include "backend/session/executor_manager.h"
  34. #include "backend/optimizer/common/common_backend_optimization.h"
  35. #include "backend/optimizer/common/helper.h"
  36. #include "runtime/device/kernel_runtime_manager.h"
  37. #include "utils/ms_utils.h"
  38. #include "ir/anf.h"
  39. #include "ir/func_graph_cloner.h"
  40. #include "utils/utils.h"
  41. #include "debug/anf_ir_dump.h"
  42. #include "debug/dump_proto.h"
  43. #include "utils/file_utils.h"
  44. #include "utils/trace_base.h"
  45. #include "frontend/parallel/context.h"
  46. #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  47. #include "ps/ps_cache/ps_cache_manager.h"
  48. #include "ps/constants.h"
  49. #include "ps/util.h"
  50. #include "ps/ps_context.h"
  51. #include "abstract/abstract_value.h"
  52. #endif
  53. #include "backend/session/session_factory.h"
  54. #include "backend/session/pynative_task_manager.h"
  55. namespace mindspore {
  56. namespace session {
  57. MS_REG_SESSION(kSessionBasic, SessionBasic);
  58. namespace {
  59. const int kSummaryGetItem = 2;
  60. const size_t max_depth = 128;
  61. bool IsShapeDynamic(const abstract::ShapePtr &shape) {
  62. if (shape == nullptr) {
  63. return false;
  64. }
  65. return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
  66. }
// Recursively checks whether the user `kernel` (a <user-node, input-index>
// pair) eventually feeds a real (non-Partial) kernel.  *idx is a shared
// counter that bounds the total traversal at max_depth.
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodePtr, int64_t> &kernel, size_t *idx) {
  auto node = kernel.first;
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(node);
  // Depend/Load uses at input index > 1 only express ordering, not data flow,
  // so they do not count as real uses.
  if (kernel.second > 1 &&
      (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
    return false;
  }
  if (AnfUtils::IsRealKernel(node) && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
    return true;
  }
  (*idx) += 1;
  // max recursion depth
  if (*idx <= max_depth) {
    // Not a real kernel itself: follow this node's users transitively.
    auto users = manager->node_users()[node];
    if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
          return RecursiveCheck(manager, kernel, idx);
        })) {
      return true;
    }
  }
  return false;
}
  90. bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, const uint32_t graph_id) {
  91. MS_EXCEPTION_IF_NULL(manager);
  92. MS_EXCEPTION_IF_NULL(node);
  93. auto node_users = manager->node_users()[node];
  94. // filter nodes not in current graph
  95. for (auto iter = node_users.begin(); iter != node_users.end();) {
  96. auto func_graph = iter->first->func_graph();
  97. auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  98. if (kernel_graph == nullptr) {
  99. MS_LOG(EXCEPTION) << "func graph cast kernel graph failed, related node is: " << iter->first->DebugString();
  100. }
  101. if (kernel_graph->graph_id() != graph_id) {
  102. iter = node_users.erase(iter);
  103. } else {
  104. iter++;
  105. }
  106. }
  107. size_t idx = 0;
  108. if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
  109. return RecursiveCheck(manager, kernel, &idx);
  110. })) {
  111. return true;
  112. }
  113. return false;
  114. }
  115. ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
  116. if (node == nullptr) {
  117. return nullptr;
  118. }
  119. auto parameter = node->cast<ParameterPtr>();
  120. if (parameter == nullptr || !parameter->has_default()) {
  121. return nullptr;
  122. }
  123. return parameter->param_info();
  124. }
  125. static bool IsPynativeMode() {
  126. auto ms_context = MsContext::GetInstance();
  127. MS_EXCEPTION_IF_NULL(ms_context);
  128. return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
  129. }
  130. BaseRef GetNodeOutputTensorFromInputs(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
  131. const std::vector<tensor::TensorPtr> &input_tensors) {
  132. auto &node = node_output_pair.first;
  133. MS_EXCEPTION_IF_NULL(node);
  134. if (HasAbstractMonad(node)) {
  135. return std::make_shared<tensor::Tensor>(int64_t(0), kBool);
  136. }
  137. // if node is a value node, no need sync addr from device to host
  138. if (node->isa<ValueNode>()) {
  139. auto value_node = node->cast<ValueNodePtr>();
  140. MS_EXCEPTION_IF_NULL(value_node);
  141. return value_node->value();
  142. }
  143. if (IsPynativeMode()) {
  144. return nullptr;
  145. }
  146. if (!node->isa<Parameter>()) {
  147. return nullptr;
  148. }
  149. MS_EXCEPTION_IF_NULL(graph);
  150. auto param_node = node->cast<ParameterPtr>();
  151. if (param_node != nullptr && param_node->IsUsedByRealKernelInGraph(graph->graph_id())) {
  152. return nullptr;
  153. }
  154. for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
  155. if (input_idx >= input_tensors.size()) {
  156. MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " is out of range:" << input_tensors.size();
  157. }
  158. if (graph->inputs()[input_idx] == node) {
  159. return input_tensors[input_idx];
  160. }
  161. }
  162. return nullptr;
  163. }
  164. int64_t ShapeSize(const std::vector<int64_t> &shape) {
  165. return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
  166. }
// Creates (or reuses) the host tensor that will receive the device output
// <node, output_index>.  Records the tensor in *tensor_to_node and tags it
// with a sync status that decides when device data is copied back to host.
BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                               const std::vector<tensor::TensorPtr> &input_tensors,
                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
  auto &node = node_output_pair.first;
  size_t output_index = node_output_pair.second;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  // Outputs backed by host data (values, untouched inputs) need no new tensor.
  auto tensor_from_input = GetNodeOutputTensorFromInputs(node_output_pair, graph, input_tensors);
  if (tensor_from_input != nullptr) {
    return tensor_from_input;
  }
  // Prefer the selected device data type; fall back to the inferred type
  // when no kernel has been selected for this output yet.
  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (type_id == kTypeUnknown) {
    type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  }
  std::vector<int64_t> temp_shape;
  auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  // For dynamic shapes allocate for the larger of the inferred vs max shape.
  if (AnfAlgo::IsDynamicShape(node)) {
    auto max_shape = AnfAlgo::GetOutputMaxShape(node, output_index);
    temp_shape = ShapeSize(max_shape) > ShapeSize(temp_shape) ? max_shape : temp_shape;
  }
  tensor::TensorPtr tensor;
  bool is_internal_output = graph->IsInternalOutput(node, output_index);
  if (is_internal_output) {
    // Internal outputs cache their tensor on the graph so repeated runs share it.
    tensor = graph->GetInternalOutputTensor(node, output_index);
    if (tensor == nullptr) {
      tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
      graph->AddInternalOutputTensor(node, output_index, tensor);
    }
  } else {
    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  }
  MS_EXCEPTION_IF_NULL(tensor);
  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
  if (is_internal_output) {
    // Consumed by the next graph directly on device: no host sync required.
    tensor->set_sync_status(kNoNeedSync);
  } else {
    // if in pynative mode,data only copied to host when user want to print data
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
        ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
      tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
    } else {
      tensor->set_sync_status(kNeedSyncDeviceToHost);
    }
  }
  tensor->SetIsGraphOutput();
  (*tensor_to_node)[tensor] = node_output_pair;
  return tensor;
}
// Recursively builds host tensors for output `anf`.  make_tuple outputs fan
// out into a VectorRef; outputs that share the same kernel reuse the tensor
// cached in *node_to_tensor.  Returns a tensor, a value, or a VectorRef
// mirroring the output structure.
BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                const std::vector<tensor::TensorPtr> &input_tensors,
                                std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                                KernelMapTensor *node_to_tensor) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  MS_EXCEPTION_IF_NULL(node_to_tensor);
  MS_LOG(DEBUG) << "Create tensor for output[" << anf->DebugString() << "]";
  // Skip over transparent wrappers (getitem etc.) to the producing kernel.
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(DEBUG) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // special handle for maketuple
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node, node_to_tensor);
      ret.push_back(out);
    }
    return ret;
  }
  // if is graph return nothing ,the function should return a null anylist
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  // The outputs of graph may have the same kernel node, no need to create new tensor.
  const auto &iter = node_to_tensor->find(item_with_index);
  if (iter != node_to_tensor->end()) {
    return iter->second;
  }
  const auto &tensor = CreateNodeOutputTensor(item_with_index, graph, input_tensors, tensor_to_node);
  (*node_to_tensor)[item_with_index] = tensor;
  return tensor;
}
  255. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  256. MS_EXCEPTION_IF_NULL(anf);
  257. MS_EXCEPTION_IF_NULL(graph);
  258. auto value_node = anf->cast<ValueNodePtr>();
  259. MS_EXCEPTION_IF_NULL(value_node);
  260. auto value = value_node->value();
  261. MS_EXCEPTION_IF_NULL(value);
  262. if (value->isa<None>()) {
  263. return nullptr;
  264. }
  265. auto new_value_node = graph->NewValueNode(value_node);
  266. graph->FrontBackendMapAdd(anf, new_value_node);
  267. graph->AddValueNodeToGraph(new_value_node);
  268. return new_value_node;
  269. }
// Builds a backend Parameter for one run-op input tensor.  Weight tensors
// (kParameterWeightTensorMask) become default params; the kernel build info
// is taken from the tensor's existing device address when present, otherwise
// defaulted.
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
                                     int64_t tensor_mask) {
  MS_EXCEPTION_IF_NULL(graph);
  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  if (tensor_mask == kParameterWeightTensorMask) {
    param->set_default_param(input_tensor);
  }
  // set the kernel info of parameter
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  if (device_address == nullptr) {
    // No device copy yet: use default format; weights stay kTypeUnknown so
    // the selected kernel can decide the device type later.
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  } else {
    // Reuse the format/type of the tensor already resident on device and
    // bind its address to the parameter's output 0.
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
    kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
    AnfAlgo::SetOutputAddr(device_address, 0, param.get());
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // construct abstract of parameter
  auto type_of_tensor = input_tensor->Dtype();
  auto shape_of_tensor = input_tensor->shape();
  auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  param->set_abstract(abstract);
  return param;
}
  300. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  301. MS_LOG(INFO) << "Graph outputs:";
  302. const size_t max_deep = 10;
  303. if (recurse_level > max_deep) {
  304. MS_LOG(INFO) << "Recurse too deep";
  305. return;
  306. }
  307. std::string tab_str;
  308. for (size_t i = 0; i < recurse_level; i++) {
  309. tab_str = tab_str.append(" ");
  310. }
  311. if (any.is<AnyList>()) {
  312. (void)tab_str.append("{");
  313. MS_LOG(INFO) << tab_str;
  314. auto any_list = any.cast<AnyList>();
  315. for (auto &it : any_list) {
  316. DumpGraphOutput(it, recurse_level + 1);
  317. }
  318. (void)tab_str.append("}");
  319. MS_LOG(INFO) << tab_str;
  320. }
  321. (void)tab_str.append(any.ToString());
  322. MS_LOG(INFO) << tab_str;
  323. }
  324. #ifndef ENABLE_SECURITY
  325. bool ExistSummaryNode(const KernelGraph *graph) {
  326. MS_EXCEPTION_IF_NULL(graph);
  327. auto ret = graph->get_return();
  328. MS_EXCEPTION_IF_NULL(ret);
  329. auto all_nodes = DeepLinkedGraphSearch(ret);
  330. for (auto &n : all_nodes) {
  331. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  332. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  333. return true;
  334. }
  335. }
  336. return false;
  337. }
  338. #endif
  339. BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
  340. const std::vector<tensor::TensorPtr> &input_tensors,
  341. const std::vector<size_t> &indexes,
  342. std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  343. auto &node = node_output_pair.first;
  344. MS_EXCEPTION_IF_NULL(node);
  345. MS_EXCEPTION_IF_NULL(graph);
  346. MS_EXCEPTION_IF_NULL(output_indexes);
  347. MS_LOG(DEBUG) << "Create placeholder for output[" << node->DebugString() << "] index[" << node_output_pair.second
  348. << "]";
  349. // if node is a value node, no need sync addr from device to host
  350. if (node->isa<ValueNode>()) {
  351. auto value_node = node->cast<ValueNodePtr>();
  352. MS_EXCEPTION_IF_NULL(value_node);
  353. return value_node->value();
  354. }
  355. if (node->isa<Parameter>()) {
  356. for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
  357. if (input_idx >= input_tensors.size()) {
  358. MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " is out of range:" << input_tensors.size();
  359. }
  360. if (graph->inputs()[input_idx] == node) {
  361. return input_tensors[input_idx];
  362. }
  363. }
  364. MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
  365. }
  366. (*output_indexes)[node_output_pair].emplace_back(indexes);
  367. BaseRef output_placeholder = std::make_shared<BaseRef>();
  368. return output_placeholder;
  369. }
// Recursive overload: resolves `anf` to its real producing node and creates
// placeholders for it.  make_tuple outputs fan out into a VectorRef, with
// `indexes` accumulating the index path of each element inside the nested
// output structure.
BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(DEBUG) << "Create placeholder for output[" << anf->DebugString() << "]";
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(DEBUG) << "Create placeholder for output after visit:" << item_with_index.first->DebugString();
  // special handle for maketuple
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      // Extend the index path with this element's position (0-based).
      std::vector<size_t> cur_index = indexes;
      cur_index.emplace_back(i - 1);
      auto out = CreateNodeOutputPlaceholder(cnode->input(i), graph, input_tensors, cur_index, output_indexes);
      ret.push_back(out);
    }
    return ret;
  }
  // if is graph return nothing ,the function should return a null anylist
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
}
  400. void CheckInputTensorShape(const TensorPtr &tensor, const CNodePtr &kernel, size_t input_index) {
  401. MS_EXCEPTION_IF_NULL(tensor);
  402. const auto &tensor_shape = tensor->shape();
  403. const auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
  404. if (tensor_shape.size() != input_shape.size()) {
  405. MS_LOG(EXCEPTION) << "The input tensor's shape size: " << tensor_shape.size()
  406. << " is not equal to expected size: " << input_shape.size() << " for input[" << input_index
  407. << "] of kernel: " << AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
  408. }
  409. for (size_t i = 0; i < tensor_shape.size(); i++) {
  410. if (tensor_shape[i] < 0 || static_cast<size_t>(tensor_shape[i]) != input_shape[i]) {
  411. MS_LOG(EXCEPTION) << "The input tensor's shape: " << tensor_shape
  412. << " is not equal to expected shape: " << input_shape << " for input[" << input_index
  413. << "] of kernel: " << AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
  414. }
  415. }
  416. }
  417. bool ExistGraphCaller(const AnfNodePtr &partial_node) {
  418. MS_EXCEPTION_IF_NULL(partial_node);
  419. auto partial_cnode = partial_node->cast<CNodePtr>();
  420. MS_EXCEPTION_IF_NULL(partial_cnode);
  421. auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
  422. MS_EXCEPTION_IF_NULL(partial_graph);
  423. auto graph_nodes = TopoSort(partial_graph->get_return());
  424. return std::any_of(graph_nodes.begin(), graph_nodes.end(), IsValueNode<FuncGraph>);
  425. }
// 1. Convert the node to make_tuple if the node is a ValueNode<ValueTuple> and it's the input of 'return' node.
// 2. Set the return of graph if node is "Return" node.
// Nodes that are not Return CNodes are left untouched.
void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
    constexpr auto kReturnInputIdx = 1;
    auto return_node = node->cast<CNodePtr>();
    graph->set_return(return_node);
    auto graph_output = return_node->input(kReturnInputIdx);
    MS_EXCEPTION_IF_NULL(graph_output);

    // If return's input is value node, then the graph has no kernel, and the pass 'trans tuple to make_tuple' cannot
    // match this pattern because that pass begin with output node but return node. So we add transform value tuple
    // to make_tuple here.
    if (AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
      return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
    }
  }
}
  445. } // namespace
// Count of kernel graphs built so far; reset to 0 by ClearGraph()
// (presumably also used to assign fresh graph ids — confirm against callers).
GraphId SessionBasic::graph_sum_ = 0;
// Binds this session to a device: stores the device id, creates the session
// Context, and fetches the shared Executor for <device_name, device_id>.
void SessionBasic::InitExecutor(const std::string &device_name, uint32_t device_id) {
  device_id_ = device_id;
  context_ = std::make_shared<Context>(device_name, device_id);
  executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);
}
  452. GraphId SessionBasic::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
  453. for (const auto &graph_item : graphs_) {
  454. auto graph = graph_item.second;
  455. MS_EXCEPTION_IF_NULL(graph);
  456. // if front_anf is a parameter,the backend parameter may have two
  457. if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
  458. return graph_item.first;
  459. }
  460. }
  461. MS_EXCEPTION_IF_NULL(front_anf);
  462. MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
  463. return kInvalidGraphId;
  464. }
  465. KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
  466. auto it = graphs_.find(graph_id);
  467. if (it == graphs_.end()) {
  468. MS_LOG(INFO) << "Can't find graph " << graph_id;
  469. return nullptr;
  470. }
  471. return it->second;
  472. }
  473. void SessionBasic::ClearGraph() {
  474. auto graph_iter = graphs_.begin();
  475. while (graph_iter != graphs_.end()) {
  476. graph_iter->second.reset();
  477. graph_iter = graphs_.erase(graph_iter);
  478. }
  479. graph_sum_ = 0;
  480. }
// When `parameter` of the current graph is fed by `out_node` — an internal
// output of a previously-built graph — copy that output's device address,
// format and device type onto the parameter so the two graphs share device
// memory instead of round-tripping through host.  Silently returns when any
// precondition (graph, internal output, kernel info, address) is missing.
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
  auto graph_id = GetGraphIdByNode(out_node);
  if (graph_id == kInvalidGraphId) {
    return;
  }
  auto node_graph = GetGraph(graph_id);
  if (node_graph == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
  if (ref_node == nullptr) {
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  // A TupleGetItem wrapper selects one element of a multi-output kernel.
  size_t output_idx = 0;
  if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
  }
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
      MS_LOG(INFO) << "No kernel info";
      return;
    }
    if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    // Mirror the producing output's address and build info onto the parameter.
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
    auto d_kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(d_kernel_info);
    parameter->set_kernel_info(d_kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsDeviceType({type});
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
    // Refresh the abstract so its element type matches the device data type.
    auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
                                                               parameter->Shape()->cast<abstract::BaseShapePtr>());
    parameter->set_abstract(abstract);
  }
}
// Creates backend parameter(s) standing in for front node `node` (which may
// produce a tuple).  Every parameter is appended to the graph inputs, and
// each element backed by a previous graph's internal output is bound to it.
// Returns the new parameter node (wrapped in make_tuple for tuple outputs).
AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
  auto parameters = AnfAlgo::GetAllOutput(new_parameter);
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  if (!pre_graph_out.empty() && !AnfUtils::IsRealKernel(node)) {
    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
  }
  for (size_t i = 0; i < parameters.size(); ++i) {
    const auto &parameter = parameters[i];
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) == true) {
      // In control flow, if the input of the cnode is a call node, it will be processed as a make_tuple input,
      // which needs to be linked when processing the internal node.
      graph->CacheInternalParameterToFrontNode(parameter, {node, i});
    }
    // Register the new parameter as a (valid) graph input.
    auto valid_inputs = graph->MutableValidInputs();
    MS_EXCEPTION_IF_NULL(valid_inputs);
    auto graph_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(graph_inputs);
    valid_inputs->push_back(true);
    graph_inputs->push_back(parameter);
  }
  // Bind each produced output of the front node to its backend parameter.
  size_t param_index = 0;
  for (const auto &out_node : pre_graph_out) {
    size_t output_size = AnfAlgo::GetOutputTensorNum(out_node);
    for (size_t i = 0; i < output_size; i++) {
      if (param_index >= parameters.size()) {
        MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << "out of range.Node:" << node->DebugString()
                          << ",out_node:" << out_node->DebugString();
      }
      InitInternalOutputParameter(out_node, parameters[param_index++]);
    }
  }
  return new_parameter;
}
// Translates a front-end Parameter into a backend Parameter of `graph`.
// If the python parameter already owns a backend parameter it is reused;
// otherwise a new one is created (and registered back on the param info).
// In either case the parameter is appended to the graph's input list.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_value = GetParamDefaultValue(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
  if (param_value != nullptr) {
    new_parameter = param_value->parameter();
  }
  if (new_parameter == nullptr) {
    TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // If this parameter is fed by another graph's output (recorded in
    // partial_parameters_map_), bind it to that internal output.
    auto input_node_iter = partial_parameters_map_.find(anf);
    if (input_node_iter != partial_parameters_map_.end()) {
      InitInternalOutputParameter(input_node_iter->second, new_parameter);
    }
    if (param_value != nullptr) {
      param_value->set_parameter(new_parameter);
    }
  }
  new_parameter->IncreaseUsedGraphCount();
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(true);
  return new_parameter;
}
// Creates backend parameter(s) replacing front cnode `anf` (typically an
// output of another graph) inside `graph`.  Delegates to
// CreateParameterFromTuple, which handles multi-output nodes.
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  return CreateParameterFromTuple(anf, graph);
}
  606. void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const {
  607. MS_EXCEPTION_IF_NULL(cnode);
  608. MS_EXCEPTION_IF_NULL(cnode_inputs);
  609. auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  610. if (prim != nullptr) {
  611. // push attr to inputs[0] of new cnode
  612. cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
  613. } else {
  614. auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
  615. MS_EXCEPTION_IF_NULL(fg);
  616. auto new_fg = BasicClone(fg);
  617. cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
  618. }
  619. }
// Translates the data inputs of front cnode into backend nodes, appending
// them to *cnode_inputs.  Already-translated inputs are reused; value nodes
// are cloned; front parameters create backend parameters; cnodes from other
// graphs become fresh parameters cached in *other_graph_cnode.
void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
                                     mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  auto origin_inputs = cnode->inputs();
  const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
  // if has multiple depends,only select first depend as parameter
  for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
    auto anf = origin_inputs[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if ((is_depend && input_idx > kRealInputIndexInDepend)) {
      // Extra Depend inputs only express ordering; a dummy value stands in.
      cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // A stand-in was already created for this foreign cnode; reuse it.
      cnode_inputs->push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // if input is a value node,
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        (void)cnode_inputs->emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      auto new_parameter = CreateNewParameterFromParameter(anf, graph);
      cnode_inputs->push_back(new_parameter);
      graph->FrontBackendMapAdd(anf, new_parameter);
      continue;
    } else {
      // the input node is a cnode from other graph
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
      if (parameter_from_cnode == nullptr) {
        parameter_from_cnode = NewValueNode(MakeValue(SizeToLong(input_idx)));
      }
      if (parameter_from_cnode->isa<Parameter>() && IsPrimitiveCNode(anf, prim::kPrimLoad)) {
        // Name the parameter after the Load's source so the weight stays traceable.
        auto para = parameter_from_cnode->cast<ParameterPtr>();
        auto load_cnode = anf->cast<CNodePtr>();
        para->set_name(load_cnode->input(kFirstDataInputIndex)->fullname_with_scope());
      }
      cnode_inputs->push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
    }
  }
}
// Clones the front cnode into `graph`: input0 (primitive or cloned func
// graph) comes from GetCNodeInfo, the remaining inputs from
// GetNewCNodeInputs; debug info of the original cnode is preserved.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
                                      mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  // get primitive of old node
  std::vector<AnfNodePtr> cnode_inputs;
  GetCNodeInfo(cnode, &cnode_inputs);
  GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
  TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
  return new_cnode;
}
// Normalizes one branch input of a Switch into a Partial cnode in `graph`:
// - already a Partial: reuse its backend node;
// - a ValueNode<FuncGraph>: wrap the backend graph in a new Partial;
// - anything else: build a trivial kernel graph that returns a parameter and
//   make a Partial of that graph with the branch input as argument.
CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node_input);
  MS_EXCEPTION_IF_NULL(graph);
  // switch input generalizes partial
  std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
  if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
    // NOTE(review): GetBackendAnfByFrontAnf may return nullptr when the
    // partial was never translated — confirm callers guarantee it exists
    // before the unchecked cast below.
    auto backend_node = graph->GetBackendAnfByFrontAnf(node_input);
    return backend_node->cast<CNodePtr>();
  } else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
    partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  } else {
    // Wrap the bare branch value in a one-parameter graph returning it.
    KernelGraphPtr kernel_graph = NewKernelGraph();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(cnode->abstract());
    auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
    auto return_node = kernel_graph->NewCNode({primitive, parameter});
    return_node->set_abstract(cnode->abstract());
    kernel_graph->set_return(return_node);
    partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
    partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  }
  auto partial_node = graph->NewCNode(partial_inputs);
  return partial_node;
}
// Builds the input list of a backend Call cnode whose callee is a Switch.
// When the front call carries real arguments, they are appended to every
// Partial branch of the switch so each branch receives them.
std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  auto switch_cnode = cnode_input->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(switch_cnode);
  // A call without real args can use the translated switch node as-is.
  if (cnode->inputs().size() <= 1) {
    cnode_inputs = switch_cnode->inputs();
    return cnode_inputs;
  }
  // Rebuild the switch: keep its primitive and condition input.
  std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
                                           switch_cnode->input(kFirstDataInputIndex)};
  for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
    auto node = switch_cnode->input(index);
    // there is real input in call, should put it to true and false branch in switch
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
      auto partial_node = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(partial_node);
      std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
      // Put all call args at the end of partial inputs.
      for (size_t i = kFirstDataInputIndex; i < cnode->size(); ++i) {
        (void)partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(i)));
      }
      auto new_partial = graph->NewCNode(partial_inputs);
      (void)switch_inputs.emplace_back(new_partial);
    }
  }
  if (switch_inputs.size() < kSwitchInputSize) {
    MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
  }
  auto switch_node = graph->NewCNode(switch_inputs);
  (void)cnode_inputs.emplace_back(switch_node);
  return cnode_inputs;
}
  747. void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
  748. const std::vector<AnfNodePtr> &real_inputs) {
  749. MS_EXCEPTION_IF_NULL(cnode);
  750. // func1 =switch(branch1, branch2)
  751. // func2 = func1(param1)
  752. // out = func2(param2)
  753. // process the last cnode(func2), not func1 which abstract is AbstractFunction
  754. if (cnode->abstract()->isa<abstract::AbstractFunction>()) {
  755. return;
  756. }
  757. MS_EXCEPTION_IF_NULL(graph);
  758. auto ret = graph->get_return();
  759. MS_EXCEPTION_IF_NULL(ret);
  760. auto return_input = ret->input(kFirstDataInputIndex);
  761. // return node is a function
  762. std::vector<AnfNodePtr> call_inputs = {
  763. graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  764. if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
  765. auto return_input_cnode = return_input->cast<CNodePtr>();
  766. auto partial_inputs = return_input_cnode->inputs();
  767. call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
  768. } else if (IsValueNode<KernelGraph>(return_input)) { // return node is kernel graph
  769. call_inputs.emplace_back(return_input);
  770. } else { // return node is value node
  771. KernelGraphPtr kernel_graph = NewKernelGraph();
  772. auto valid_inputs = kernel_graph->MutableValidInputs();
  773. MS_EXCEPTION_IF_NULL(valid_inputs);
  774. auto graph_inputs = kernel_graph->MutableInputs();
  775. MS_EXCEPTION_IF_NULL(graph_inputs);
  776. std::vector<AnfNodePtr> cnode_inputs = {return_input};
  777. for (auto &real_input : real_inputs) {
  778. auto new_parameter = kernel_graph->NewParameter(real_input->abstract());
  779. valid_inputs->push_back(true);
  780. graph_inputs->push_back(new_parameter);
  781. cnode_inputs.push_back(new_parameter);
  782. }
  783. auto new_cnode = kernel_graph->NewCNode(cnode_inputs);
  784. new_cnode->set_abstract(cnode->abstract());
  785. std::vector<AnfNodePtr> return_inputs = {
  786. kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))), new_cnode};
  787. auto return_node = kernel_graph->NewCNode(return_inputs);
  788. return_node->set_abstract(cnode->abstract());
  789. kernel_graph->set_return(return_node);
  790. call_inputs.push_back(std::make_shared<ValueNode>(kernel_graph));
  791. }
  792. // new call node inputs
  793. for (auto &input_node : real_inputs) {
  794. auto parameter_for_input = CreateNewParameterFromCNode(input_node, graph);
  795. call_inputs.emplace_back(parameter_for_input);
  796. }
  797. auto call_node = graph->NewCNode(call_inputs);
  798. call_node->set_abstract(cnode->abstract());
  799. // update return input
  800. ret->set_input(kFirstDataInputIndex, call_node);
  801. }
  802. std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
  803. MS_EXCEPTION_IF_NULL(cnode);
  804. MS_EXCEPTION_IF_NULL(graph);
  805. std::vector<AnfNodePtr> cnode_inputs = {
  806. graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  807. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  808. MS_EXCEPTION_IF_NULL(attr_input);
  809. auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  810. auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
  811. MS_EXCEPTION_IF_NULL(switch_layer_cnode);
  812. std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
  813. switch_layer_cnode->input(kFirstDataInputIndex)};
  814. auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex);
  815. MS_EXCEPTION_IF_NULL(make_tuple_node);
  816. auto node = make_tuple_node->cast<CNodePtr>();
  817. MS_EXCEPTION_IF_NULL(node);
  818. auto make_tuple_inputs = node->inputs();
  819. // there are real inputs in call, should put it to make_tuple in switch_layer
  820. std::vector<AnfNodePtr> real_inputs;
  821. for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) {
  822. real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx)));
  823. }
  824. std::vector<AnfNodePtr> new_make_tuple_inputs = {
  825. graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
  826. for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
  827. auto partial_idx = make_tuple_inputs[idx];
  828. MS_EXCEPTION_IF_NULL(cnode->abstract());
  829. std::vector<AnfNodePtr> new_partial_inputs;
  830. KernelGraphPtr partial_kernel_graph;
  831. // switch_layer node input is partial cnode
  832. if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
  833. auto partial_node = partial_idx->cast<CNodePtr>();
  834. MS_EXCEPTION_IF_NULL(partial_node);
  835. auto partial_input = partial_node->input(kFirstDataInputIndex);
  836. partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
  837. new_partial_inputs = partial_node->inputs();
  838. } else if (IsValueNode<KernelGraph>(partial_idx)) { // switch_layer node input is kernel graph value node
  839. new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
  840. new_partial_inputs.emplace_back(partial_idx);
  841. partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_idx);
  842. }
  843. // when branch in swich_layer return function
  844. MS_EXCEPTION_IF_NULL(partial_kernel_graph);
  845. auto ret = partial_kernel_graph->get_return();
  846. MS_EXCEPTION_IF_NULL(ret);
  847. auto return_input = ret->input(kFirstDataInputIndex);
  848. if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
  849. ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
  850. }
  851. // partial node add input args
  852. new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
  853. // create new partial node
  854. auto new_partial = graph->NewCNode(new_partial_inputs);
  855. new_make_tuple_inputs.emplace_back(new_partial);
  856. }
  857. auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
  858. auto abstract = make_tuple_node->abstract();
  859. if (abstract == nullptr) {
  860. abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
  861. }
  862. new_make_tuple->set_abstract(abstract);
  863. switch_layer_inputs.emplace_back(new_make_tuple);
  864. auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
  865. cnode_inputs.emplace_back(new_switch_layer);
  866. return cnode_inputs;
  867. }
  868. std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
  869. MS_EXCEPTION_IF_NULL(cnode);
  870. MS_EXCEPTION_IF_NULL(graph);
  871. // create primitive of cnode:call(partial or switch or switch_layer)
  872. std::vector<AnfNodePtr> cnode_inputs = {
  873. graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  874. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  875. MS_EXCEPTION_IF_NULL(attr_input);
  876. auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  877. if (cnode_input == nullptr) {
  878. MS_LOG(ERROR) << "CNode input[0] is CNode:" << attr_input->DebugString() << ", but input[0] has not been created.";
  879. return {};
  880. }
  881. // if the node is partial, insert the inputs of partial to the call
  882. if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
  883. auto partial_node = attr_input->cast<CNodePtr>();
  884. MS_EXCEPTION_IF_NULL(partial_node);
  885. auto partial_inputs = partial_node->inputs();
  886. (void)std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
  887. std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
  888. MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
  889. return graph->GetBackendAnfByFrontAnf(node);
  890. });
  891. return cnode_inputs;
  892. } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
  893. return CreateCallSwitchInputs(cnode, graph);
  894. } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
  895. return CreateCallSwitchLayerInputs(cnode, graph);
  896. }
  897. MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
  898. << "must be partial or switch or switch_layer.";
  899. return {};
  900. }
  901. std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
  902. MS_EXCEPTION_IF_NULL(cnode);
  903. MS_EXCEPTION_IF_NULL(graph);
  904. std::vector<AnfNodePtr> cnode_inputs;
  905. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  906. MS_EXCEPTION_IF_NULL(attr_input);
  907. if (AnfAlgo::IsGraphKernel(cnode)) {
  908. auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
  909. MS_EXCEPTION_IF_NULL(fg);
  910. auto new_fg = BasicClone(fg);
  911. cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  912. } else {
  913. // create primitive of cnode:call
  914. cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  915. // create a ValueNode<KernelGraph> as input of cnode:call
  916. if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
  917. cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
  918. } else {
  919. auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
  920. if (new_value_node != nullptr) {
  921. cnode_inputs.emplace_back(new_value_node);
  922. }
  923. }
  924. }
  925. return cnode_inputs;
  926. }
  927. void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs) {
  928. MS_EXCEPTION_IF_NULL(cnode);
  929. MS_EXCEPTION_IF_NULL(graph);
  930. if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
  931. (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
  932. for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
  933. auto node_input = cnode->input(index);
  934. auto switch_input = CreateSwitchInput(cnode, node_input, graph);
  935. (void)cnode_inputs->emplace_back(switch_input);
  936. }
  937. } else {
  938. for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) {
  939. auto anf = cnode->input(input_idx);
  940. MS_EXCEPTION_IF_NULL(anf);
  941. // anf has been created before
  942. if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
  943. (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
  944. continue;
  945. } else if (IsValueNode<None>(anf)) {
  946. continue;
  947. }
  948. MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  949. }
  950. }
  951. }
// Create a backend cnode mirroring the front-end `cnode` inside `graph`.
// Dispatches on the front cnode's inputs[0]:
//   - ValueNode<FuncGraph>: the cnode is a graph call (CreateValueNode);
//   - CNode: the cnode is a call of partial/switch/switch_layer
//     (CreateSwitchOrPartialNode); returns nullptr if that fails;
//   - otherwise: an ordinary primitive cnode (primitive is deep-copied).
// After construction, a call-of-switch(_layer) node is collapsed by dropping
// the outer Call and returning the inner switch node directly.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (IsValueNode<FuncGraph>(attr_input)) {
    // cnode is a graph or a call
    cnode_inputs = CreateValueNode(cnode, graph);
  } else if (attr_input->isa<CNode>()) {
    // cnode is a call (partial/switch/switch_layer)
    // 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph
    // 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
    if (cnode_inputs.empty()) {
      MS_LOG_ERROR << "Create switch or partial failed, cnode:" << cnode->DebugString();
      return nullptr;
    }
  } else {
    // get primitive of old node
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode (copy so later attr edits don't touch the front prim)
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  // handle inputs of cnode except primitive
  CreateCNodeInputs(cnode, graph, &cnode_inputs);
  // Keep the front node's debug info on the new backend node.
  TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
  // if the cnode is call switch, remove call
  if (new_cnode->inputs().size() > 1) {
    auto first_input = new_cnode->input(kFirstDataInputIndex);
    MS_EXCEPTION_IF_NULL(first_input);
    if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
        AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
      new_cnode = first_input->cast<CNodePtr>();
    }
    // NOTE: this second test runs on the (possibly replaced) new_cnode; for
    // the switch_layer case the abstract of the original front cnode is
    // carried over onto the inner node.
    if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
        AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
      auto abstract = cnode->abstract();
      new_cnode = first_input->cast<CNodePtr>();
      new_cnode->set_abstract(abstract);
    }
  }
  return new_cnode;
}
  998. ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  999. MS_EXCEPTION_IF_NULL(anf);
  1000. MS_EXCEPTION_IF_NULL(graph);
  1001. auto value_node = anf->cast<ValueNodePtr>();
  1002. MS_EXCEPTION_IF_NULL(value_node);
  1003. auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  1004. MS_EXCEPTION_IF_NULL(sub_func_graph);
  1005. if (front_backend_graph_map_.find(sub_func_graph.get()) == front_backend_graph_map_.end()) {
  1006. MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  1007. }
  1008. auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph.get()];
  1009. ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  1010. new_value_node->set_abstract(value_node->abstract());
  1011. // create new kernel_info of new value_node
  1012. auto kernel_info = std::make_shared<device::KernelInfo>();
  1013. new_value_node->set_kernel_info(kernel_info);
  1014. // create kernel_build_info for new value node
  1015. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  1016. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  1017. AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  1018. graph->FrontBackendMapAdd(anf, new_value_node);
  1019. return new_value_node;
  1020. }
  1021. ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  1022. MS_EXCEPTION_IF_NULL(anf);
  1023. MS_EXCEPTION_IF_NULL(graph);
  1024. if (!anf->isa<Parameter>()) {
  1025. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  1026. }
  1027. auto param_value = GetParamDefaultValue(anf);
  1028. ParameterPtr new_parameter = nullptr;
  1029. // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
  1030. if (param_value != nullptr) {
  1031. new_parameter = param_value->parameter();
  1032. if (new_parameter == nullptr) {
  1033. TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
  1034. new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  1035. param_value->set_parameter(new_parameter);
  1036. }
  1037. } else {
  1038. TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
  1039. new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  1040. }
  1041. new_parameter->IncreaseUsedGraphCount();
  1042. return new_parameter;
  1043. }
// Build a new KernelGraph from a topologically-ordered list of front-end
// cnodes (`lst`) and the front-end output nodes (`outputs`).
// Every node in `lst` must be a CNode; each is mirrored into the backend graph
// and the front->backend mapping recorded. Afterwards the graph output, the
// manager, the execution order, summary flag, and (when MindRT is disabled)
// the backend optimization passes are set up.
// `common_opt` controls whether BackendCommonOptimization runs.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
                                                  bool common_opt) {
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    // NOTE(review): this is a 3-arg overload of CreateNewCNode (with the
    // cross-graph cnode cache) declared elsewhere in this file.
    auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // A Load node keeps the scoped name of its real (first data) input.
    if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
      new_cnode->set_fullname_with_scope(cnode->input(kFirstDataInputIndex)->fullname_with_scope());
    }
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
#ifndef ENABLE_SECURITY
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
#endif
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  // With MindRT enabled these graph-level passes are handled elsewhere.
  if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    UnifyMindIR(graph);
    graph->UpdateGraphAquireGilAttr();
    if (common_opt) {
      opt::BackendCommonOptimization(graph);
    }
    graph->SetInputNodes();
    SetInputNodeUsage(graph, manager);
    graph->SetOptimizerFlag();
  }
  return graph;
}
  1095. void SessionBasic::SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager) {
  1096. MS_EXCEPTION_IF_NULL(graph);
  1097. MS_EXCEPTION_IF_NULL(manager);
  1098. auto input_nodes = graph->input_nodes();
  1099. for (auto &input_node : input_nodes) {
  1100. if (input_node->isa<Parameter>()) {
  1101. auto node_ptr = input_node->cast<ParameterPtr>();
  1102. MS_EXCEPTION_IF_NULL(node_ptr);
  1103. if (!IsUsedByRealKernel(manager, input_node, graph->graph_id())) {
  1104. node_ptr->SetNotUsedByRealKernelInGraph(graph->graph_id());
  1105. }
  1106. auto shape = node_ptr->Shape();
  1107. if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
  1108. node_ptr->set_has_dynamic_shape(true);
  1109. }
  1110. }
  1111. }
  1112. }
  1113. GraphInfo SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel,
  1114. const std::vector<tensor::TensorPtr> &input_tensors) {
  1115. MS_EXCEPTION_IF_NULL(kernel);
  1116. auto prim = AnfAlgo::GetCNodePrimitive(kernel);
  1117. MS_EXCEPTION_IF_NULL(prim);
  1118. const AbstractBasePtr &abstract = kernel->abstract();
  1119. MS_EXCEPTION_IF_NULL(abstract);
  1120. size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
  1121. GraphInfo graph_info;
  1122. // get input tensor info
  1123. for (const auto &tensor : input_tensors) {
  1124. MS_EXCEPTION_IF_NULL(tensor);
  1125. auto tensor_shape = tensor->shape();
  1126. (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
  1127. [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
  1128. (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
  1129. if (tensor->device_address() != nullptr) {
  1130. const auto type_id = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id();
  1131. (void)graph_info.append(std::to_string(type_id) + "_");
  1132. const auto format = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format();
  1133. (void)graph_info.append(format + "_");
  1134. }
  1135. for (const auto &padding_type : tensor->padding_type()) {
  1136. (void)graph_info.append(std::to_string(padding_type) + "_");
  1137. }
  1138. }
  1139. // get attr info
  1140. const auto &attr_map = prim->attrs();
  1141. (void)std::for_each(attr_map.begin(), attr_map.end(), [&](const auto &element) {
  1142. if (element.second->ToString().empty()) {
  1143. return;
  1144. }
  1145. (void)graph_info.append(element.second->ToString() + "_");
  1146. });
  1147. auto build_shape = abstract->BuildShape();
  1148. MS_EXCEPTION_IF_NULL(build_shape);
  1149. (void)graph_info.append(build_shape->ToString() + "_");
  1150. for (size_t output_index = 0; output_index < output_num; output_index += 1) {
  1151. const auto output_type = AnfAlgo::GetOutputInferDataType(kernel, output_index);
  1152. (void)graph_info.append(std::to_string(output_type) + "_");
  1153. }
  1154. graph_info.append(std::to_string(prim->id()));
  1155. return graph_info;
  1156. }
  1157. OpRunInfo SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInfo &graph_info,
  1158. const InputTensorInfo &tensor_info) {
  1159. MS_EXCEPTION_IF_NULL(cnode);
  1160. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  1161. const auto &abstract = cnode->abstract();
  1162. if (abstract == nullptr) {
  1163. MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
  1164. }
  1165. const auto &shape = abstract->BuildShape();
  1166. MS_EXCEPTION_IF_NULL(shape);
  1167. OpRunInfo op_run_info = {.op_name = primitive->name(),
  1168. .primitive = primitive.get(),
  1169. .abstract = abstract,
  1170. .is_dynamic_shape = shape->IsDynamic(),
  1171. .is_auto_mixed_precision = false,
  1172. .lazy_build = false,
  1173. .next_op_name = std::string(),
  1174. .next_input_index = 0,
  1175. .graph_info = graph_info,
  1176. .tensor_mask = tensor_info.input_tensors_mask,
  1177. .input_tensors = tensor_info.input_tensors};
  1178. return op_run_info;
  1179. }
  1180. void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
  1181. std::map<AnfNodePtr, size_t> *parameter_index) {
  1182. size_t index = 0;
  1183. for (const auto &input_node : graph->inputs()) {
  1184. auto params = AnfAlgo::GetAllOutput(input_node);
  1185. for (const auto &param : params) {
  1186. if (index >= inputs.size()) {
  1187. MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
  1188. << ", input size: " << inputs.size();
  1189. }
  1190. const auto &input = inputs[index];
  1191. MS_EXCEPTION_IF_NULL(input);
  1192. // Check shape of input and parameter
  1193. const auto &input_shape = input->shape();
  1194. const auto &param_shape = AnfAlgo::GetOutputInferShape(param, 0);
  1195. if (input_shape.size() != param_shape.size()) {
  1196. MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
  1197. << ", parameter: " << param->fullname_with_scope();
  1198. }
  1199. bool is_dynamic = param->Shape()->IsDynamic();
  1200. for (size_t i = 0; i < input_shape.size(); i += 1) {
  1201. if (input_shape[i] < 0 || (static_cast<size_t>(input_shape[i]) != param_shape[i] && !is_dynamic)) {
  1202. MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
  1203. << ", parameter: " << param->fullname_with_scope();
  1204. }
  1205. }
  1206. parameter_index->emplace(param, index++);
  1207. }
  1208. }
  1209. }
  1210. void SessionBasic::CreateOutputPlaceholder(
  1211. const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *const outputs,
  1212. std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  1213. MS_EXCEPTION_IF_NULL(kernel_graph);
  1214. MS_EXCEPTION_IF_NULL(outputs);
  1215. MS_EXCEPTION_IF_NULL(output_indexes);
  1216. auto anf_outputs = kernel_graph->outputs();
  1217. size_t index = 0;
  1218. for (auto &item : anf_outputs) {
  1219. MS_EXCEPTION_IF_NULL(item);
  1220. std::vector<size_t> indexes{index++};
  1221. outputs->emplace_back(CreateNodeOutputPlaceholder(item, kernel_graph, input_tensors, indexes, output_indexes));
  1222. }
  1223. }
  1224. void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
  1225. MS_EXCEPTION_IF_NULL(graph);
  1226. for (const auto &kernel : graph->execution_order()) {
  1227. for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
  1228. const auto &input = kernel->input(i);
  1229. auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
  1230. const auto &node = kernel_with_index.first;
  1231. if (node->isa<CNode>()) {
  1232. (*ref_count)[kernel_with_index] += 1;
  1233. }
  1234. }
  1235. }
  1236. }
  1237. void SessionBasic::HandleOpInputs(const std::set<KernelWithIndex> &input_kernel,
  1238. std::map<KernelWithIndex, size_t> *ref_count,
  1239. std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
  1240. MS_EXCEPTION_IF_NULL(ref_count);
  1241. MS_EXCEPTION_IF_NULL(op_output_map);
  1242. for (auto &kernel_with_index : input_kernel) {
  1243. MS_EXCEPTION_IF_NULL(kernel_with_index.first);
  1244. if (!kernel_with_index.first->isa<CNode>()) {
  1245. continue;
  1246. }
  1247. auto ref_iter = ref_count->find(kernel_with_index);
  1248. if (ref_iter == ref_count->end()) {
  1249. MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
  1250. << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
  1251. }
  1252. // Reduce reference count number, when it was reduced to zero, release the useless output of pre node.
  1253. ref_iter->second -= 1;
  1254. if (ref_iter->second != 0) {
  1255. continue;
  1256. }
  1257. ref_count->erase(ref_iter);
  1258. auto output_iter = op_output_map->find(kernel_with_index);
  1259. if (output_iter == op_output_map->end()) {
  1260. MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
  1261. << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
  1262. }
  1263. op_output_map->erase(output_iter);
  1264. }
  1265. }
// After running `kernel`, store its output tensors:
//   - into *op_output_map for outputs that later kernels still reference
//     (present in ref_count);
//   - into the final graph outputs (graph_output_info->graph_outputs), walking
//     the nested index path recorded in output_indexes to find each slot.
void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                                   const std::map<KernelWithIndex, size_t> &ref_count,
                                   std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map,
                                   GraphOutputInfo *const graph_output_info) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_output_map);
  MS_EXCEPTION_IF_NULL(graph_output_info);
  MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
  auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
  // Flattening must not yield more tensors than VectorRef entries; otherwise
  // the op produced an unexpected nested tuple.
  if (output_tensors.size() > op_outputs.size()) {
    MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
  }
  size_t out_index = 0;
  for (const auto &output_tensor : output_tensors) {
    auto kernel_with_index = make_pair(kernel, out_index++);
    // Cache the tensor only if some later kernel consumes this output.
    if (ref_count.find(kernel_with_index) != ref_count.end()) {
      (*op_output_map)[kernel_with_index] = output_tensor;
    }
    const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
    if (iter == graph_output_info->output_indexes.end()) {
      continue;
    }
    // One kernel output may feed multiple graph-output slots.
    const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
    for (const auto &ref_indexes : multiple_ref_indexes) {
      size_t n = 0;
      const VectorRef *cur_vector_ref = graph_output_info->graph_outputs;
      // Descend through all but the last index of the path; each step must
      // land on a nested VectorRef.
      for (; n < ref_indexes.size() - 1; n += 1) {
        size_t index = ref_indexes.at(n);
        if (index >= cur_vector_ref->size()) {
          MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
                            << cur_vector_ref->size();
        }
        const BaseRef &base_ref = (*cur_vector_ref)[index];
        if (!utils::isa<VectorRef>(base_ref)) {
          MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
        }
        cur_vector_ref = &utils::cast<VectorRef>(base_ref);
      }
      // Write the tensor into the final slot (placeholder created earlier);
      // const_cast is needed because the walk above used a const pointer.
      BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
      tensor_ref = output_tensor;
      graph_output_info->graph_output_tensors.emplace_back(output_tensor);
    }
  }
}
  1310. TensorPtr SessionBasic::GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index) {
  1311. MS_EXCEPTION_IF_NULL(node);
  1312. if (!node->isa<ValueNode>()) {
  1313. return nullptr;
  1314. }
  1315. auto value_node = node->cast<ValueNodePtr>();
  1316. MS_EXCEPTION_IF_NULL(value_node);
  1317. auto value = GetValueNode(value_node);
  1318. MS_EXCEPTION_IF_NULL(value);
  1319. if (value->isa<ValueTuple>()) {
  1320. auto value_tuple = value->cast<ValueTuplePtr>();
  1321. MS_EXCEPTION_IF_NULL(value_tuple);
  1322. if (output_index >= value_tuple->size()) {
  1323. MS_LOG(EXCEPTION) << "Index " << output_index << "is out of value tuple range";
  1324. }
  1325. auto tensor_value = value_tuple->value()[output_index];
  1326. if (tensor_value->isa<tensor::Tensor>()) {
  1327. return tensor_value->cast<tensor::TensorPtr>();
  1328. }
  1329. } else if (value->isa<tensor::Tensor>()) {
  1330. if (output_index != 0) {
  1331. MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << output_index;
  1332. }
  1333. return value->cast<TensorPtr>();
  1334. } else if (value->isa<StringImm>()) {
  1335. auto value_string = GetValue<std::string>(value);
  1336. const ShapeVector shape = {1, SizeToLong(value_string.size())};
  1337. TensorPtr tensor = std::make_shared<Tensor>(kObjectTypeString, shape, value_string.data(), value_string.size());
  1338. MS_EXCEPTION_IF_NULL(tensor);
  1339. tensor->set_sync_status(kNeedSyncHostToDevice);
  1340. return tensor;
  1341. }
  1342. return nullptr;
  1343. }
  1344. TensorPtr SessionBasic::GetParameterOutputTensor(const AnfNodePtr &node,
  1345. const std::map<AnfNodePtr, size_t> &parameter_index,
  1346. const std::vector<tensor::TensorPtr> &graph_inputs) {
  1347. MS_EXCEPTION_IF_NULL(node);
  1348. if (!node->isa<Parameter>()) {
  1349. return nullptr;
  1350. }
  1351. const auto &iter = parameter_index.find(node);
  1352. if (iter == parameter_index.end()) {
  1353. MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, parameter = " << node->DebugString();
  1354. }
  1355. const size_t index = iter->second;
  1356. if (index >= graph_inputs.size()) {
  1357. MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = " << index
  1358. << ", input tensor size = " << graph_inputs.size();
  1359. }
  1360. return graph_inputs[index];
  1361. }
  1362. TensorPtr SessionBasic::GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
  1363. const std::map<KernelWithIndex, tensor::TensorPtr> &op_output) {
  1364. const auto &iter = op_output.find(kernel_with_index);
  1365. if (iter == op_output.end()) {
  1366. MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << kernel_with_index.first->DebugString();
  1367. }
  1368. return iter->second;
  1369. }
  1370. void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
  1371. const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
  1372. const std::map<AnfNodePtr, size_t> &parameter_index,
  1373. const std::vector<tensor::TensorPtr> &graph_inputs,
  1374. InputTensorInfo *input_tensor_info) {
  1375. MS_EXCEPTION_IF_NULL(cnode);
  1376. MS_EXCEPTION_IF_NULL(input_tensor_info);
  1377. const auto input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
  1378. for (size_t i = 1; i <= input_tensor_num; i += 1) {
  1379. const auto &input = cnode->input(i);
  1380. auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
  1381. auto real_input = kernel_with_index.first;
  1382. MS_EXCEPTION_IF_NULL(real_input);
  1383. tensor::TensorPtr tensor = nullptr;
  1384. if (real_input->isa<ValueNode>()) {
  1385. tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second);
  1386. const auto &value_ptr = GetValueNode(real_input);
  1387. MS_EXCEPTION_IF_NULL(value_ptr);
  1388. input_tensor_info->input_tensors_mask.emplace_back(value_ptr->isa<StringImm>() ? kValueNodeTensorMask
  1389. : kParameterDataTensorMask);
  1390. } else if (real_input->isa<Parameter>()) {
  1391. tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
  1392. input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask
  1393. : kParameterDataTensorMask);
  1394. } else if (real_input->isa<CNode>()) {
  1395. tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
  1396. if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
  1397. CheckInputTensorShape(tensor, cnode, i - 1);
  1398. }
  1399. input_tensor_info->input_kernel.insert(kernel_with_index);
  1400. input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask
  1401. : kParameterDataTensorMask);
  1402. } else {
  1403. MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
  1404. }
  1405. MS_EXCEPTION_IF_NULL(tensor);
  1406. MS_LOG(DEBUG) << "Get" << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
  1407. << real_input->fullname_with_scope() << "-" << kernel_with_index.second;
  1408. input_tensor_info->input_tensors.emplace_back(tensor);
  1409. }
  1410. }
  1411. tensor::TensorPtr SessionBasic::GetOpInputTensorByIndex(const CNodePtr &cnode,
  1412. const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
  1413. const std::map<AnfNodePtr, size_t> &parameter_index,
  1414. const std::vector<tensor::TensorPtr> &graph_inputs,
  1415. InputTensorInfo *const input_tensor_info, size_t input_index) {
  1416. MS_EXCEPTION_IF_NULL(cnode);
  1417. MS_EXCEPTION_IF_NULL(input_tensor_info);
  1418. if (input_index >= cnode->inputs().size() - 1) {
  1419. MS_LOG(EXCEPTION) << "Input index is out of range:" << cnode->inputs().size() << ",cnode:" << cnode->DebugString();
  1420. }
  1421. const auto &input = cnode->input(input_index + 1);
  1422. auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
  1423. auto real_input = kernel_with_index.first;
  1424. MS_EXCEPTION_IF_NULL(real_input);
  1425. if (real_input->isa<Parameter>()) {
  1426. return GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
  1427. } else if (real_input->isa<CNode>()) {
  1428. tensor::TensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
  1429. if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
  1430. CheckInputTensorShape(tensor, cnode, input_index);
  1431. }
  1432. input_tensor_info->input_kernel.insert(kernel_with_index);
  1433. return tensor;
  1434. } else {
  1435. MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
  1436. }
  1437. }
  1438. bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
  1439. MS_EXCEPTION_IF_NULL(node);
  1440. MS_EXCEPTION_IF_NULL(graph);
  1441. auto cnode = node->cast<CNodePtr>();
  1442. MS_EXCEPTION_IF_NULL(cnode);
  1443. // create a new cnode object
  1444. auto new_cnode = CreateNewCNode(cnode, graph);
  1445. if (new_cnode == nullptr) {
  1446. return false;
  1447. }
  1448. new_cnode->set_abstract(cnode->abstract());
  1449. std::string fullname;
  1450. if (cnode->input(kAnfPrimitiveIndex)->isa<CNode>()) {
  1451. fullname = cnode->input(kAnfPrimitiveIndex)->fullname_with_scope();
  1452. } else if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
  1453. fullname = cnode->input(kFirstDataInputIndex)->fullname_with_scope();
  1454. } else {
  1455. fullname = cnode->fullname_with_scope();
  1456. }
  1457. new_cnode->set_fullname_with_scope(fullname);
  1458. new_cnode->set_scope(cnode->scope());
  1459. graph->FrontBackendMapAdd(node, new_cnode);
  1460. SetReturnNode(new_cnode, graph);
  1461. return true;
  1462. }
// Build a backend KernelGraph from a front-end FuncGraph.
// Walks the func graph in topological order and mirrors every node into a new
// KernelGraph: Parameters become graph inputs, common ValueNodes are copied,
// ValueNode<FuncGraph> triggers a recursive build of the child graph (once per
// child, guarded by front_backend_graph_map_), and CNodes are cloned via
// CreateCNodeOfKernelGraph. Every graph built (children first) is appended to
// all_out_graph; the returned graph corresponds to func_graph itself.
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                                std::vector<KernelGraphPtr> *all_out_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_out_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  // Register the mapping before walking nodes so recursive constructions that
  // refer back to this func graph find it instead of rebuilding it.
  front_backend_graph_map_[func_graph.get()] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    // Create parameter
    if (node->isa<Parameter>()) {
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendMapAdd(node, new_parameter);
      continue;
    }
    // Create value node
    if (node->isa<ValueNode>()) {
      // Create common value node
      if (!IsValueNode<FuncGraph>(node)) {
        (void)CreateNewValueNode(node, graph.get());
        continue;
      }
      // Create child kernel graph according ValueNode<FuncGraph>
      FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
      if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
        (void)ConstructKernelGraph(child_graph, all_out_graph);
      }
      (void)CreateValueNodeKernelGraph(node, graph.get());
      continue;
    }
    // Create cnode
    if (!CreateCNodeOfKernelGraph(node, graph.get())) {
#ifdef ENABLE_DUMP_IR
      DumpIR("construct_kernel_graph_fail.ir", func_graph);
#endif
      MS_LOG(EXCEPTION) << "Construct func graph " << func_graph->ToString() << " failed."
                        << trace::DumpSourceLines(node);
    }
  }
  // Re-sync graph inputs with the front-end parameter order (this also covers
  // parameters that were never used), then finalize execution metadata.
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  FuncGraphManagerPtr manager = MakeManager({graph});
  graph->SetInputNodes();
  SetInputNodeUsage(graph, manager);
  graph->SetExecOrderByDefault();
#ifndef ENABLE_SECURITY
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
#endif
  all_out_graph->push_back(graph);
  return graph;
}
  1521. void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  1522. MS_EXCEPTION_IF_NULL(graph);
  1523. auto graph_inputs = graph->MutableInputs();
  1524. MS_EXCEPTION_IF_NULL(graph_inputs);
  1525. graph_inputs->clear();
  1526. for (auto &parameter : parameters) {
  1527. MS_EXCEPTION_IF_NULL(parameter);
  1528. auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
  1529. if (backend_parameter == nullptr) {
  1530. // for example "def f(x,y,z) {return x + y}", parameter z in unused
  1531. auto new_parameter = CreateNewParameter(parameter, graph);
  1532. graph_inputs->push_back(new_parameter);
  1533. graph->FrontBackendMapAdd(parameter, new_parameter);
  1534. MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
  1535. continue;
  1536. }
  1537. graph_inputs->push_back(backend_parameter);
  1538. }
  1539. }
// Fill `outputs` with tensors for every graph output and bind each output
// tensor to the device address of the node that produced it. Dynamic-shape
// outputs get their tensor shape refreshed from the inferred output shape;
// in graph (non-PyNative) mode the data is also synced back to host eagerly.
void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                                 const std::vector<tensor::TensorPtr> &input_tensors,
                                 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  KernelMapTensor node_to_tensor;
  auto anf_outputs = kernel_graph->outputs();
  // First pass: create one output entry per graph output, recording the
  // tensor -> (node, index) mapping used below.
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(DEBUG) << "Update output[" << item->DebugString() << "]";
    outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // Second pass: attach device addresses and refresh tensor state.
  for (auto &item : *tensor_to_node) {
    auto &tensor = item.first;
    auto &node = item.second.first;
    auto &output_index = item.second.second;
    DeviceAddressPtr address = nullptr;
    // NOTE(review): PyNative-infer mode passes false as the third argument of
    // GetMutableOutputAddr — presumably to skip some address indirection;
    // confirm against the AnfAlgo API before relying on this.
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
        ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
      address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
    } else {
      address = AnfAlgo::GetMutableOutputAddr(node, output_index);
    }
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->set_device_address(address);
    // The output is ready; consumers no longer need to wait on this tensor.
    tensor->SetNeedWait(false);
    MS_LOG(DEBUG) << "Debug address: Output tensor obj " << tensor.get() << ", tensor id " << tensor->id()
                  << ", device address " << tensor->device_address().get();
    if (AnfAlgo::IsDynamicShape(node)) {
      // Dynamic-shape nodes may have produced a different shape than inferred
      // at build time — refresh the tensor shape from the updated inference.
      const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
      ShapeVector int_shape;
      (void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
      (void)tensor->set_shape(int_shape);
    }
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
      // Graph mode: copy device data to host now and mark the tensor as
      // needing a host-to-device sync before its next device use.
      tensor->data_sync(false);
      tensor->set_sync_status(kNeedSyncHostToDevice);
    }
  }
}
  1583. void SessionBasic::UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph,
  1584. OpRunInfo *op_run_info) const {
  1585. MS_EXCEPTION_IF_NULL(kernel_graph);
  1586. MS_EXCEPTION_IF_NULL(op_run_info);
  1587. const auto &kernels = kernel_graph->execution_order();
  1588. for (const auto &kernel : kernels) {
  1589. MS_EXCEPTION_IF_NULL(kernel);
  1590. if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
  1591. op_run_info->abstract = kernel->abstract();
  1592. }
  1593. }
  1594. }
  1595. std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id,
  1596. const std::vector<tensor::TensorPtr> &inputs) {
  1597. auto graph = GetGraph(graph_id);
  1598. MS_EXCEPTION_IF_NULL(graph);
  1599. if (!graph->has_optimizer()) {
  1600. return {};
  1601. }
  1602. auto input_nodes = graph->inputs();
  1603. bool check_monad = false;
  1604. if (input_nodes.size() == inputs.size()) {
  1605. check_monad = true;
  1606. }
  1607. std::vector<tensor::TensorPtr> result;
  1608. for (size_t i = 0; i < inputs.size(); ++i) {
  1609. if (check_monad && HasAbstractMonad(input_nodes[i])) {
  1610. continue;
  1611. }
  1612. auto &tensor = inputs[i];
  1613. MS_EXCEPTION_IF_NULL(tensor);
  1614. if (!tensor->IsGraphOutput()) {
  1615. result.emplace_back(tensor);
  1616. }
  1617. }
  1618. return result;
  1619. }
  1620. void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
  1621. VectorRef *outputs,
  1622. std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
  1623. KernelMapTensor *node_to_tensor) {
  1624. auto kernel_graph = GetGraph(graph_id);
  1625. MS_EXCEPTION_IF_NULL(kernel_graph);
  1626. MS_EXCEPTION_IF_NULL(outputs);
  1627. MS_EXCEPTION_IF_NULL(tensor_to_node);
  1628. auto anf_outputs = kernel_graph->outputs();
  1629. for (auto &item : anf_outputs) {
  1630. MS_EXCEPTION_IF_NULL(item);
  1631. MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
  1632. outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, node_to_tensor));
  1633. }
  1634. }
// Re-bind output tensors to the current device addresses of their producing
// nodes after a run, recursing into nested VectorRefs. Tensors flagged for
// immediate host consumption are synced to host and detached from device
// memory. A no-op when the memory scheduler owns device memory. The unnamed
// third parameter (a new-to-old device address map) is accepted but unused in
// this base implementation.
void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
                                       const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                                       std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (device::KernelRuntime::UseMemScheduler()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(outputs);
  for (const auto &item : *outputs) {
    if (utils::isa<VectorRefPtr>(item)) {
      // Nested output tuple: recurse with a fresh (unused) address map.
      const auto &vector_ref = utils::cast<VectorRef>(item);
      std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
      UpdateOutputTensors(&vector_ref, tensor_to_node, &new_to_old_device_address);
    } else if (utils::isa<tensor::TensorPtr>(item)) {
      const auto &tensor = utils::cast<tensor::TensorPtr>(item);
      MS_EXCEPTION_IF_NULL(tensor);
      const auto &iter = tensor_to_node.find(tensor);
      if (iter != tensor_to_node.end()) {
        const auto &node = iter->second.first;
        const auto &output_index = iter->second.second;
        // Only rebind when the node actually has an output address; otherwise
        // leave the tensor untouched.
        if (!AnfAlgo::OutputAddrExist(node, output_index, true)) {
          continue;
        }
        const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);
        tensor->set_device_address(address);
        if (AnfAlgo::IsDynamicShape(node)) {
          // Dynamic-shape outputs: refresh the tensor shape from the updated
          // inferred shape.
          const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
          ShapeVector int_shape;
          (void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
          (void)tensor->set_shape(int_shape);
        }
      }
      if (tensor->NeedSyncDeviceToHostImmediately()) {
        // Pull the data to host now, drop the device binding, and mark the
        // tensor dirty for the next host-to-device sync.
        tensor->data_sync(false);
        tensor->set_device_address(nullptr);
        tensor->set_sync_status(kNeedSyncHostToDevice);
      }
    }
  }
}
  1676. void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
  1677. std::vector<std::string> *inputs_name) const {
  1678. MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
  1679. auto kernel_graph = GetGraph(graph_id);
  1680. MS_EXCEPTION_IF_NULL(kernel_graph);
  1681. MS_EXCEPTION_IF_NULL(inputs);
  1682. MS_EXCEPTION_IF_NULL(inputs_name);
  1683. auto kernel_graph_inputs = kernel_graph->inputs();
  1684. // find parameters of graph inputs
  1685. for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
  1686. if (!kernel_graph_inputs[i]->isa<Parameter>()) {
  1687. MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
  1688. continue;
  1689. }
  1690. auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
  1691. if (!AnfAlgo::IsParameterWeight(parameter)) {
  1692. vector<int64_t> input_shape;
  1693. auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
  1694. (void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
  1695. [](const size_t dim) { return SizeToLong(dim); });
  1696. auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
  1697. auto data_type = kernel_build_info->GetOutputDeviceType(0);
  1698. auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
  1699. inputs->push_back(ms_tensor);
  1700. inputs_name->push_back(parameter->name());
  1701. }
  1702. }
  1703. }
  1704. void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
  1705. std::vector<std::string> *output_names) const {
  1706. std::vector<tensor::TensorPtr> inputs;
  1707. std::vector<std::string> input_names;
  1708. GetModelInputsInfo(graph_id, &inputs, &input_names);
  1709. auto kernel_graph = GetGraph(graph_id);
  1710. MS_EXCEPTION_IF_NULL(kernel_graph);
  1711. MS_EXCEPTION_IF_NULL(outputs);
  1712. MS_EXCEPTION_IF_NULL(output_names);
  1713. VectorRef vector_outputs;
  1714. std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
  1715. KernelMapTensor node_to_tensor;
  1716. auto anf_outputs = kernel_graph->outputs();
  1717. for (auto &item : anf_outputs) {
  1718. MS_EXCEPTION_IF_NULL(item);
  1719. MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
  1720. vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node, &node_to_tensor));
  1721. }
  1722. *outputs = TransformVectorRefToMultiTensor(vector_outputs);
  1723. for (size_t i = 0; i < outputs->size(); i++) {
  1724. output_names->push_back("output" + std::to_string(i));
  1725. }
  1726. }
  1727. #ifndef ENABLE_SECURITY
  1728. void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  1729. MS_EXCEPTION_IF_NULL(callback);
  1730. summary_callback_ = callback;
  1731. }
  1732. void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
  1733. MS_LOG(DEBUG) << "Update summary Start";
  1734. MS_EXCEPTION_IF_NULL(graph);
  1735. if (!graph->summary_node_exist()) {
  1736. return;
  1737. }
  1738. auto summary = graph->summary_nodes();
  1739. auto apply_list = TopoSort(graph->get_return());
  1740. for (auto &n : apply_list) {
  1741. MS_EXCEPTION_IF_NULL(n);
  1742. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  1743. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  1744. auto cnode = n->cast<CNodePtr>();
  1745. MS_EXCEPTION_IF_NULL(cnode);
  1746. if (cnode->inputs().size() <= kSummaryGetItem) {
  1747. MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least, but got " << cnode->inputs().size() - 1
  1748. << ". trace: " << trace::DumpSourceLines(cnode);
  1749. }
  1750. auto node = cnode->input(kSummaryGetItem);
  1751. MS_EXCEPTION_IF_NULL(node);
  1752. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
  1753. MS_EXCEPTION_IF_NULL(item_with_index.first);
  1754. if (!AnfUtils::IsRealKernel(item_with_index.first)) {
  1755. MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
  1756. }
  1757. summary[n->fullname_with_scope()] = item_with_index;
  1758. }
  1759. }
  1760. graph->set_summary_nodes(summary);
  1761. MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
  1762. }
// Collect the host-side values of every summary node in the graph and pass
// them to the registered summary callback. A no-op when no callback is
// registered or the graph has no summary nodes.
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  bool exist_summary = graph->summary_node_exist();
  if (!exist_summary) {
    return;
  }
  // Warn only once per process if the environment cannot collect summaries.
  static bool is_first = true;
  if (is_first && !IsSupportSummary()) {
    is_first = false;
    MS_LOG(ERROR) << "The Summary operator can not collect data correctly. Detail: the data sink mode is used and the"
                     " sink size(in model.train() python api) is not equal to 1.";
  }
  SetSummaryNodes(graph);
  auto summary_outputs = graph->summary_nodes();
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    std::vector<int64_t> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    // Allocate a host tensor matching the node's inferred type/shape, then
    // copy the device data into it.
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    // Skip nodes whose device memory was never allocated/populated.
    if (!address->GetPtr()) {
      continue;
    }
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c())) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_sync_status(kNoNeedSync);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
  1805. #endif
  1806. namespace {
  1807. bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
  1808. if (node == nullptr) {
  1809. return false;
  1810. }
  1811. auto cnode = node->cast<CNodePtr>();
  1812. if (cnode == nullptr) {
  1813. return false;
  1814. }
  1815. auto prim = cnode->input(kAnfPrimitiveIndex);
  1816. if (prim == nullptr || !IsValueNode<Primitive>(prim)) {
  1817. return false;
  1818. }
  1819. return true;
  1820. }
  1821. std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
  1822. const AnfNodePtr &front_node) {
  1823. MS_EXCEPTION_IF_NULL(front_func_graph_manager);
  1824. auto &users = front_func_graph_manager->node_users()[front_node];
  1825. std::vector<AnfNodePtr> result;
  1826. for (auto &user : users) {
  1827. if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend) ||
  1828. AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimLoad)) {
  1829. auto depend_cnode = user.first->cast<CNodePtr>();
  1830. if (depend_cnode == nullptr) {
  1831. continue;
  1832. }
  1833. if (front_node != depend_cnode->input(1)) {
  1834. continue;
  1835. }
  1836. auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
  1837. result.insert(result.end(), res.begin(), res.end());
  1838. } else if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
  1839. auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
  1840. (void)result.insert(result.end(), res.begin(), res.end());
  1841. } else {
  1842. (void)result.emplace_back(user.first);
  1843. }
  1844. }
  1845. return result;
  1846. }
  1847. AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
  1848. MS_EXCEPTION_IF_NULL(front_node);
  1849. if (!front_node->isa<CNode>()) {
  1850. return nullptr;
  1851. }
  1852. if (AnfUtils::IsRealKernel(front_node)) {
  1853. return front_node;
  1854. }
  1855. if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
  1856. return front_node;
  1857. }
  1858. if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimMakeTuple)) {
  1859. auto cnode = front_node->cast<CNodePtr>();
  1860. MS_EXCEPTION_IF_NULL(cnode);
  1861. auto &inputs = cnode->inputs();
  1862. if (inputs.size() > 1) {
  1863. return GetSupportedInternalNode(inputs[1]);
  1864. }
  1865. }
  1866. if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) {
  1867. auto cnode = front_node->cast<CNodePtr>();
  1868. MS_EXCEPTION_IF_NULL(cnode);
  1869. auto &inputs = cnode->inputs();
  1870. if (inputs.size() >= kDependInputSize) {
  1871. return GetSupportedInternalNode(inputs[kRealInputIndexInDepend]);
  1872. }
  1873. }
  1874. return nullptr;
  1875. }
  1876. bool IsUnusedInternlOutput(const AnfNodePtr &user) {
  1877. if (!CNodeFirstInputIsPrimitive(user)) {
  1878. return true;
  1879. }
  1880. if (IsPrimitiveCNode(user, prim::kPrimSwitch) || IsPrimitiveCNode(user, prim::kPrimSwitchLayer)) {
  1881. return true;
  1882. }
  1883. if (!AnfUtils::IsRealKernel(user)) {
  1884. return true;
  1885. }
  1886. return false;
  1887. }
  1888. } // namespace
// Sentinel targets used by AddPartialParametersMap: kMixTarget marks a graph
// whose real kernels span multiple device targets; kNoTarget means no real
// kernel target has been seen yet.
constexpr auto kMixTarget = "MixTarget";
constexpr auto kNoTarget = "NoTarget";
  1891. std::string SessionBasic::AddPartialParametersMap(const AnfNodePtr &partial_node) {
  1892. MS_EXCEPTION_IF_NULL(partial_node);
  1893. auto iter = partial_target_map_.find(partial_node);
  1894. if (iter != partial_target_map_.end()) {
  1895. return iter->second;
  1896. }
  1897. auto partial_cnode = partial_node->cast<CNodePtr>();
  1898. MS_EXCEPTION_IF_NULL(partial_cnode);
  1899. auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
  1900. MS_EXCEPTION_IF_NULL(partial_graph);
  1901. auto parameters = partial_graph->parameters();
  1902. auto partial_inputs = partial_cnode->inputs();
  1903. const size_t kNonParameterNum = 2;
  1904. if (parameters.size() + kNonParameterNum != partial_inputs.size()) {
  1905. return kMixTarget;
  1906. }
  1907. for (size_t i = 0; i < parameters.size(); ++i) {
  1908. partial_parameters_map_[parameters[i]] = partial_inputs[kNonParameterNum + i];
  1909. }
  1910. auto graph_nodes = TopoSort(partial_graph->get_return());
  1911. std::string graph_target = kNoTarget;
  1912. for (auto &node : graph_nodes) {
  1913. if (!node->isa<CNode>()) {
  1914. continue;
  1915. }
  1916. if (!AnfUtils::IsRealKernel(node)) {
  1917. continue;
  1918. }
  1919. std::string cur_target = GetCNodeTarget(node);
  1920. if (graph_target == kNoTarget) {
  1921. graph_target = cur_target;
  1922. }
  1923. if (graph_target != cur_target) {
  1924. graph_target = kMixTarget;
  1925. break;
  1926. }
  1927. }
  1928. (void)partial_target_map_.emplace(std::pair<AnfNodePtr, std::string>(partial_node, graph_target));
  1929. return graph_target;
  1930. }
// Decide whether a front-end output node can be registered as an "internal
// output" of the backend graph (i.e. an output consumed by other nodes), and
// if so register it with a unique-target flag describing whether all of its
// users run on the same device target.
void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
                                        const FuncGraphManagerPtr &front_func_graph_manager,
                                        const std::shared_ptr<KernelGraph> &backend_graph) {
  auto front_node = GetSupportedInternalNode(input_front_node);
  if (front_node == nullptr) {
    return;
  }
  auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
  auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);
  auto backend_real_kernel = backend_real_kernel_pair.first;
  if (backend_real_kernel == nullptr || !backend_real_kernel->isa<CNode>()) {
    return;
  }
  auto front_real_kernel = front_real_kernel_pair.first;
  std::string kernel_target = GetCNodeTarget(front_real_kernel);
  bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel);
  bool unique_target = true;
  // A nop node forwards its input: its effective target is only unique when
  // the producing node shares the same target.
  if (internal_output && opt::IsNopNode(front_real_kernel)) {
    auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
    auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
    if (pre_node_target != kernel_target) {
      unique_target = false;
    }
  }
  if (internal_output) {
    // Inspect every (extended) user of the front node; any "unused" user
    // disqualifies the internal output, any differing target clears
    // unique_target.
    auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
    for (auto &user : users) {
      // Partial users (outside GPU, without a graph caller) contribute the
      // target of their wrapped graph via AddPartialParametersMap.
      if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
          !ExistGraphCaller(user)) {
        auto partial_target = AddPartialParametersMap(user);
        if (partial_target != kNoTarget && partial_target != kernel_target) {
          unique_target = false;
        }
        continue;
      }
      // UpdateState users do not affect the decision.
      if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
        continue;
      }
      if (IsUnusedInternlOutput(user)) {
        internal_output = false;
        break;
      }
      if (kernel_target != GetCNodeTarget(user)) {
        unique_target = false;
      }
    }
  }
  if (internal_output) {
    MS_LOG(INFO) << "AddInternalOutput: " << front_node->DebugString() << " To " << backend_real_kernel->DebugString()
                 << ", unique_target: " << unique_target;
    backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target);
  }
}
// Build the backend graph's output node: a MakeTuple over the backend
// counterparts of the given front-end outputs. In graph mode, each output is
// also considered for internal-output registration via HandleInternalOutput.
// Throws when a front-end output has no backend counterpart.
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> output_args;
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    MS_LOG(INFO) << "Output:" << output->DebugString();
  }
  // Map a front-end output to its backend node; side effect: registers
  // internal outputs (graph mode only, and only when a manager exists).
  auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr {
    auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
    if (backend_anf != nullptr) {
      auto context_ptr = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context_ptr);
      // PyNative mode skips internal-output handling entirely.
      if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
        return backend_anf;
      }
      MS_EXCEPTION_IF_NULL(out);
      auto out_func_graph = out->func_graph();
      MS_EXCEPTION_IF_NULL(out_func_graph);
      auto out_func_graph_manager = out_func_graph->manager();
      if (out_func_graph_manager == nullptr) {
        return backend_anf;
      }
      HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph);
      return backend_anf;
    }
    MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  };
  output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
                       [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  return graph->NewCNode(output_args);
}
  2016. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  2017. std::vector<AnfNodePtr> make_tuple_inputs;
  2018. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  2019. MS_EXCEPTION_IF_NULL(graph);
  2020. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  2021. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  2022. auto idx = NewValueNode(SizeToLong(output_index));
  2023. MS_EXCEPTION_IF_NULL(idx);
  2024. auto imm = std::make_shared<Int64Imm>(output_index);
  2025. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  2026. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  2027. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  2028. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  2029. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  2030. make_tuple_inputs.push_back(getitem);
  2031. }
  2032. } else {
  2033. make_tuple_inputs.push_back(cnode);
  2034. }
  2035. // create output
  2036. auto g_output = graph->NewCNode(make_tuple_inputs);
  2037. graph->set_output(g_output);
  2038. }
// Build a single-op KernelGraph for PyNative execution: one cnode wrapping
// op_run_info.primitive, with inputs materialized from input_tensors
// according to tensors_mask (value-node tensors become value nodes, the rest
// become graph input parameters). On Ascend the cnode itself is the graph
// output; elsewhere outputs are expanded via CreateOutputNode.
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                                  const std::vector<tensor::TensorPtr> &input_tensors,
                                                                  const std::vector<int64_t> &tensors_mask,
                                                                  bool is_ascend) {
  auto graph = std::make_shared<KernelGraph>();
  graph->set_graph_id(graph_sum_);
  graph_sum_++;
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  auto op_prim = op_run_info.primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  // Decoupling of frontend PrimitivePy and backend Primitive
  inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*op_prim)));
  // set input parameter
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    // Value-node-masked tensors are baked into the graph as constants.
    if (tensors_mask[i] == kValueNodeTensorMask) {
      auto value_node = graph->NewValueNode(input_tensors[i]);
      inputs.push_back(value_node);
      continue;
    }
    // Everything else becomes a graph input parameter.
    auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
    inputs.push_back(parameter);
    auto mutable_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(mutable_inputs);
    mutable_inputs->push_back(parameter);
  }
  // set execution order
  auto cnode = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  // set abstract,which include inferred shapes and types
  cnode->set_abstract(op_run_info.abstract);
  // get output dynamic shape info
  AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info.is_dynamic_shape), cnode);
  if (op_run_info.is_auto_mixed_precision) {
    AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info.next_op_name), cnode);
    AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info.next_input_index), cnode);
  }
  // set execution order
  std::vector<CNodePtr> exe_order = {cnode};
  graph->set_execution_order(exe_order);
  // set output
  if (is_ascend) {
    graph->set_output(cnode);
  } else {
    CreateOutputNode(cnode, graph);
  }
  graph->SetInputNodes();
  auto manager = MakeManager({graph});
  if (manager != nullptr) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    UnifyMindIR(graph);
  }
  graph->UpdateGraphDynamicAttr();
  return graph;
}
  2103. KernelGraphPtr SessionBasic::NewKernelGraph() {
  2104. auto graph = std::make_shared<KernelGraph>();
  2105. graph->set_graph_id(graph_sum_);
  2106. graphs_[graph_sum_++] = graph;
  2107. return graph;
  2108. }
  2109. AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) {
  2110. MS_EXCEPTION_IF_NULL(push_node);
  2111. for (auto &node : node_list) {
  2112. if (node != nullptr && node->isa<CNode>()) {
  2113. for (auto input : node->cast<CNodePtr>()->inputs()) {
  2114. if (push_node == AnfAlgo::VisitKernel(input, 0).first) {
  2115. if (AnfAlgo::GetCNodeName(node) != kPullOpName) {
  2116. MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
  2117. }
  2118. return node;
  2119. }
  2120. }
  2121. }
  2122. }
  2123. return nullptr;
  2124. }
// Delegate segment compilation to the session executor (may run async).
GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->CompileGraph(shared_from_this(), segment, outputs);
}
// Delegate whole-func-graph compilation to the session executor.
GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->CompileGraph(shared_from_this(), func_graph);
}
// Delegate graph building (kernel selection/compilation) to the executor.
void SessionBasic::BuildGraph(GraphId graph_id) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->BuildGraph(shared_from_this(), graph_id);
}
// Run a single op (PyNative mode) through the executor, unpacking the run
// info's graph info, input tensors and tensor mask.
void SessionBasic::RunOp(OpRunInfo *op_run_info, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  MS_EXCEPTION_IF_NULL(op_run_info);
  executor_->RunOp(shared_from_this(), op_run_info, op_run_info->graph_info, &op_run_info->input_tensors, outputs,
                   op_run_info->tensor_mask);
}
// Run a compiled graph op-by-op (PyNative) through the executor.
void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                 VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs);
}
// Run a compiled graph synchronously through the executor.
void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
}
// Run a compiled graph asynchronously through the executor.
void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                 VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
}
  2157. void SessionBasic::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  2158. VectorRef *const outputs) {
  2159. MS_LOG(INFO) << "Status record: start run graph. graph id: " << graph_id;
  2160. auto kernel_graph = GetGraph(graph_id);
  2161. MS_EXCEPTION_IF_NULL(kernel_graph);
  2162. // if none of child graph and no anf output exists
  2163. if (!kernel_graph->executable()) {
  2164. MS_LOG(INFO) << "No child graph has anf output";
  2165. return;
  2166. }
  2167. PreExecuteGraph(kernel_graph, inputs, outputs);
  2168. ExecuteGraph(kernel_graph);
  2169. PostExecuteGraph(kernel_graph, inputs, outputs);
  2170. MS_LOG(INFO) << "Status record: end run graph. graph id: " << graph_id;
  2171. }
  2172. device::DeviceAddressType DeviceTargetToDeviceType(const std::string &device_target) {
  2173. static const std::unordered_map<std::string, device::DeviceAddressType> target_type = {
  2174. {"Unknown", device::DeviceAddressType::kUnknown},
  2175. {"Ascend", device::DeviceAddressType::kAscend},
  2176. {"CPU", device::DeviceAddressType::kCPU},
  2177. {"GPU", device::DeviceAddressType::kGPU},
  2178. {"Davinci", device::DeviceAddressType::kAscend}};
  2179. auto iter = target_type.find(device_target);
  2180. if (iter == target_type.end()) {
  2181. MS_LOG(EXCEPTION) << "Not support device target: " << device_target;
  2182. }
  2183. return iter->second;
  2184. }
  2185. void SessionBasic::ProcessInputTensorsForHeterogeneous(const std::string &cur_target,
  2186. const std::vector<tensor::TensorPtr> &input_tensors) {
  2187. for (auto &tensor : input_tensors) {
  2188. auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  2189. if (device_address != nullptr) {
  2190. if (device_address->DeviceType() != DeviceTargetToDeviceType(cur_target)) {
  2191. tensor->data_sync();
  2192. tensor->set_device_address(nullptr);
  2193. }
  2194. }
  2195. }
  2196. }
// Execute a compiled graph one op at a time (PyNative bprop path).
// For every kernel in execution order: gather its input tensors (from graph
// inputs and previously produced op outputs), build/run it as a single-op
// graph, then release inputs whose refcount dropped to zero and record
// outputs that feed later ops or the graph result.
void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                     VectorRef *outputs) {
  MS_LOG(INFO) << "Clean task in Queue";
  session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
  MS_LOG(INFO) << "Start!";
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Map each graph parameter to the index of the tensor that feeds it.
  std::map<AnfNodePtr, size_t> parameter_index;
  GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
  GraphOutputInfo graph_output_info;
  graph_output_info.graph_outputs = outputs;
  CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes);
  // Refcount of each (node, output-index): used to free op outputs early.
  std::map<KernelWithIndex, size_t> cnode_refcount;
  GetRefCount(kernel_graph.get(), &cnode_refcount);
  BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
  // Clear bucket resources every step
  if (kernel_graph->is_bprop()) {
    ClearAllBucket(graph_id);
  }
  std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
  for (const auto &kernel : kernel_graph->execution_order()) {
    // Generate input tensors, tensor masks and input kernel with index
    InputTensorInfo input_tensor_info;
    GetOpInputTensors(kernel, op_output_map, parameter_index, inputs, &input_tensor_info);
    VectorRef op_outputs;
    // Get OpRunInfo and GraphInfo
    GraphInfo graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
    OpRunInfo run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info);
    // Build and run current single op
    RunOpImplOrigin(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
                    input_tensor_info.input_tensors_mask);
    graph_output_info.graph_output_tensors.clear();
    // Handle inputs and outputs of current op
    HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
    HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
    // Save grad node to Bucket
    if (kernel_graph->is_bprop()) {
      AddGradAddrToBucket(graph_id, graph_output_info.graph_output_tensors);
    }
  }
  MS_LOG(INFO) << "Finish!";
}
  2239. void SessionBasic::EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
  2240. std::vector<tensor::TensorPtr> *input_tensors) const {
  2241. MS_EXCEPTION_IF_NULL(input_tensors);
  2242. if (input_tensors->size() != tensors_mask.size()) {
  2243. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
  2244. << tensors_mask.size();
  2245. }
  2246. std::vector<tensor::TensorPtr> new_input_tensors;
  2247. for (size_t index = 0; index < tensors_mask.size(); ++index) {
  2248. if (tensors_mask[index] != kValueNodeTensorMask) {
  2249. new_input_tensors.emplace_back(input_tensors->at(index));
  2250. }
  2251. }
  2252. *input_tensors = new_input_tensors;
  2253. }
  2254. bool SessionBasic::IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name) {
  2255. MS_EXCEPTION_IF_NULL(kernel_graph);
  2256. for (const auto &kernel_node : kernel_graph->execution_order()) {
  2257. auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
  2258. if (kernel_name == kGetNextOpName) {
  2259. auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
  2260. MS_EXCEPTION_IF_NULL(prim);
  2261. *channel_name = GetValue<std::string>(prim->GetAttr("shared_name"));
  2262. return true;
  2263. }
  2264. }
  2265. return false;
  2266. }
  2267. void SessionBasic::RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const {
  2268. auto ms_context = MsContext::GetInstance();
  2269. MS_EXCEPTION_IF_NULL(ms_context);
  2270. if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
  2271. opt::RemoveNopNode(kernel_graph.get());
  2272. }
  2273. }
  2274. void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) {
  2275. auto ms_context = MsContext::GetInstance();
  2276. MS_EXCEPTION_IF_NULL(ms_context);
  2277. if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
  2278. opt::HideNopNode(kernel_graph.get());
  2279. }
  2280. }
  2281. std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
  2282. auto ms_context = MsContext::GetInstance();
  2283. MS_EXCEPTION_IF_NULL(ms_context);
  2284. std::string group = GetCommWorldGroup();
  2285. auto parallel_context = parallel::ParallelContext::GetInstance();
  2286. MS_EXCEPTION_IF_NULL(parallel_context);
  2287. // PyNative not support multi group allreduce
  2288. group += "sum1";
  2289. return parallel_context->GetAllReduceFusionSplitIndices(group);
  2290. }
  2291. uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
  2292. return AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size();
  2293. }
  2294. void SetGraphBpropAttr(const KernelGraphPtr &graph) {
  2295. auto &execution_orders = graph->execution_order();
  2296. if (std::any_of(execution_orders.begin(), execution_orders.end(),
  2297. [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
  2298. graph->set_is_bprop(true);
  2299. MS_LOG(INFO) << "Match bprop graph";
  2300. } else {
  2301. graph->set_is_bprop(false);
  2302. }
  2303. }
  2304. std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const std::vector<uint32_t> &split_index) {
  2305. if (split_index.empty()) {
  2306. auto grads_count = GetBpropGraphGradsCount(graph);
  2307. if (grads_count == 0) {
  2308. MS_LOG(EXCEPTION) << "Bprop graph has no grad";
  2309. }
  2310. return {grads_count};
  2311. }
  2312. std::vector<uint32_t> bucket_size_list;
  2313. uint32_t old_index = 0;
  2314. for (const auto &index : split_index) {
  2315. if (old_index == 0) {
  2316. bucket_size_list.emplace_back(index - old_index + 1);
  2317. } else {
  2318. bucket_size_list.emplace_back(index - old_index);
  2319. }
  2320. old_index = index;
  2321. }
  2322. return bucket_size_list;
  2323. }
  2324. void CheckSplitIndexValid(const vector<uint32_t> &split_index) {
  2325. uint32_t last = 0;
  2326. for (size_t i = 0; i < split_index.size(); ++i) {
  2327. if (split_index[i] <= last && i != 0) {
  2328. MS_LOG(EXCEPTION) << "Invalid split index:" << split_index;
  2329. }
  2330. last = split_index[i];
  2331. }
  2332. }
  2333. void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
  2334. MS_EXCEPTION_IF_NULL(split_index);
  2335. if (split_index->empty()) {
  2336. return;
  2337. }
  2338. CheckSplitIndexValid(*split_index);
  2339. // calculate split index num
  2340. auto split_index_num = split_index->back();
  2341. // obtain graph output tensor num
  2342. auto grads_count = GetBpropGraphGradsCount(graph);
  2343. if (split_index_num >= grads_count) {
  2344. MS_LOG(WARNING) << "The context configuration all_reduce_fusion_config's upper boundary value should be smaller "
  2345. << "than total grads count: " << grads_count << ", but got: " << *split_index
  2346. << ". Now all AllReduce operations will be fused into one AllReduce operation.";
  2347. split_index->clear();
  2348. split_index->push_back(grads_count - 1);
  2349. } else if (split_index_num < grads_count - 1) {
  2350. split_index->push_back(grads_count - 1);
  2351. }
  2352. }
// Create the gradient all-reduce buckets for a graph.
// Buckets are only created in PyNative mode with DATA_PARALLEL parallel mode,
// and only for graphs tagged as bprop graphs (SetGraphBpropAttr). Bucket sizes
// are derived from the all_reduce_fusion split indices. When device_context is
// non-null buckets come from it, otherwise from the session's CreateBucket.
// Throws on a duplicate graph id in bucket_map_/free_bucket_id_map_.
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  // Buckets only make sense for PyNative data-parallel training.
  if (!pynative_mode || parallel_mode != parallel::DATA_PARALLEL) {
    return;
  }
  SetGraphBpropAttr(graph);
  if (!graph->is_bprop()) {
    return;
  }
  std::vector<std::shared_ptr<device::Bucket>> bucket_list;
  // Create bucket for every split allreduce ops
  auto split_index = GetAllReduceSplitIndex();
  PreProcessOnSplitIndex(graph, &split_index);
  auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
  uint32_t bucket_id = 0;
  for (const auto &bucket_size : bucket_size_list) {
    MS_LOG(INFO) << "Create new bucket:" << bucket_id << " size:" << bucket_size;
    std::shared_ptr<device::Bucket> bucket = nullptr;
    if (device_context != nullptr) {
      bucket = device_context->CreateBucket(bucket_id++, bucket_size);
    } else {
      bucket = CreateBucket(bucket_id++, bucket_size);
    }
    bucket_list.emplace_back(bucket);
  }
  auto bucket_ret = bucket_map_.try_emplace(graph->graph_id(), bucket_list);
  if (!bucket_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate bucket_map_ graph key:" << graph->graph_id();
  }
  // set all free bucket index to 0
  auto free_bucket_ret = free_bucket_id_map_.try_emplace(graph->graph_id(), 0);
  if (!free_bucket_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate free_bucket_id_map_ graph key:" << graph->graph_id();
  }
  MS_LOG(INFO) << "Init Bucket finish";
}
// Append gradient tensors to the current "free" bucket of the given graph.
// When a bucket fills up, its all-reduce is launched and the next bucket in
// the list becomes the free one. No-op unless running in DATA_PARALLEL mode.
// Throws when the graph has no registered buckets or the free index runs past
// the bucket list.
void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  if (parallel_mode != parallel::DATA_PARALLEL) {
    return;
  }
  auto iter = bucket_map_.find(graph_id);
  if (iter == bucket_map_.end()) {
    MS_LOG(EXCEPTION) << "unknown graph id:" << graph_id;
  }
  auto &bucket_list = iter->second;
  auto free_bucket_iter = free_bucket_id_map_.find(graph_id);
  if (free_bucket_iter == free_bucket_id_map_.end()) {
    MS_LOG(EXCEPTION) << "unknown free graph id:" << graph_id;
  }
  auto free_bucket_index = free_bucket_iter->second;
  for (auto &tensor : grad_tensor) {
    if (free_bucket_index >= bucket_list.size()) {
      MS_LOG(EXCEPTION) << "Invalid free bucket id:" << free_bucket_iter->second
                        << " total bucket num:" << bucket_list.size();
    }
    auto &free_bucket = bucket_list[free_bucket_index];
    free_bucket->AddGradTensor(tensor);
    if (free_bucket->full()) {
      MS_LOG(INFO) << "bucket is full";
      free_bucket->Launch();
      // Advance both the local index and the persisted map entry.
      free_bucket_index = ++free_bucket_iter->second;
      MS_LOG(INFO) << "new free bucket:" << free_bucket_index;
    }
  }
}
  2428. void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
  2429. auto iter = bucket_map_.find(graph_id);
  2430. if (iter != bucket_map_.end()) {
  2431. auto bucket_list = iter->second;
  2432. for (auto &bucket : bucket_list) {
  2433. MS_LOG(INFO) << "Clear bucket:" << bucket->id();
  2434. bucket->Release();
  2435. }
  2436. }
  2437. auto free_iter = free_bucket_id_map_.find(graph_id);
  2438. if (free_iter != free_bucket_id_map_.end()) {
  2439. free_iter->second = 0;
  2440. }
  2441. }
  2442. void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const {
  2443. MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id();
  2444. opt::CommonFinalOptimization(graph);
  2445. MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id();
  2446. }
  2447. void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
  2448. #ifdef ENABLE_DUMP_IR
  2449. auto context_ptr = MsContext::GetInstance();
  2450. MS_EXCEPTION_IF_NULL(context_ptr);
  2451. bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  2452. if (save_graphs) {
  2453. DumpIR("graph_build_" + std::to_string(kernel_graph->graph_id()) + ".ir", kernel_graph, true, kWholeStack);
  2454. DumpIRProto(kernel_graph, "vm_build_" + std::to_string(kernel_graph->graph_id()));
  2455. DumpIR("trace_code_graph", kernel_graph, true, kWholeStack);
  2456. }
  2457. #endif
  2458. }
  2459. void SessionBasic::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIR(graph); }
  2460. #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
// Initialize the parameter-server worker side for this graph (worker role
// only). With the embedding cache enabled, set up the ps cache once — sizing
// it from the InitDataSetQueue kernel if that is the graph's first kernel —
// and start data processing; otherwise assign parameter keys to PS kernels.
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
  if (!ps::PSContext::instance()->is_worker()) {
    return;
  }
  CheckPSModeConsistence(kernel_graph);
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    if (!ps::ps_cache_instance.initialized_ps_cache()) {
      auto context_ptr = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context_ptr);
      auto device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
      auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_target, device_id_);
      MS_EXCEPTION_IF_NULL(runtime_instance);
      auto context = runtime_instance->context();
      const auto &kernels = kernel_graph->execution_order();
      // Only a dataset graph (first kernel InitDataSetQueue) triggers cache init.
      if (kernels.size() > 0 && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") {
        GetBatchElements(kernels[0]);
        ps::ps_cache_instance.Initialize();
      }
      ps::ps_cache_instance.DoProcessData(device_id_, context);
    }
  } else {
    // Assign parameter keys.
    AssignParamKey(kernel_graph);
  }
}
  2486. void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const {
  2487. auto shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
  2488. auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types");
  2489. if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) {
  2490. MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size "
  2491. << types;
  2492. }
  2493. size_t batch_elements = 1;
  2494. const auto &shape = shapes[0];
  2495. for (size_t i = 0; i < shape.size(); ++i) {
  2496. batch_elements *= LongToSize(shape[i]);
  2497. }
  2498. ps::ps_cache_instance.set_batch_elements(batch_elements);
  2499. }
  2500. void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const {
  2501. auto input_nodes = kernel_graph->inputs();
  2502. for (const auto &input_node : input_nodes) {
  2503. if (!input_node->isa<Parameter>()) {
  2504. continue;
  2505. }
  2506. auto pk_node = input_node->cast<ParameterPtr>();
  2507. MS_EXCEPTION_IF_NULL(pk_node);
  2508. auto param_info_ptr = pk_node->param_info();
  2509. const std::string &param_name = pk_node->fullname_with_scope();
  2510. if (param_info_ptr != nullptr && param_info_ptr->init_in_server() &&
  2511. !ps::ps_cache_instance.IsHashTable(param_name)) {
  2512. MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name
  2513. << "] in server, this parameter is used by kernel which executes in device";
  2514. }
  2515. }
  2516. }
// Assign parameter-server keys to PS-related kernels in the graph.
// EmbeddingLookup kernels get a key derived from their embedding table;
// Push kernels share a key with their paired Pull node's trainable parameter,
// and the key is bound to the Push node's optimizer type.
// Must not be called when the embedding-lookup cache is enabled.
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // PS embeddingLookup cache check.
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    MS_LOG(EXCEPTION) << "The other parameter can't set ps mode when the embeddingLookup cache is enabled in "
                         "parameter server training mode.";
  }
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
  for (auto &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      // Assign key for forward kernel EmbeddingLookup.
      // The key will be assigned to embedding table and Push kernel as well.
      if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
        size_t embedding_table_idx = 0;
        auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
        size_t key = ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope());
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
      } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
        auto pull_node = FindPullNode(node, node_list);
        if (!pull_node) {
          MS_LOG(EXCEPTION) << "Assigning parameter key failed: can't find Pull node of the Push node.";
        }
        // Second input of Pull node is the trainable parameter.
        size_t parameter_index = 1;
        auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
        size_t key = ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope());
        // Push and its paired Pull share the same key.
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);
        std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
        ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name);
      }
    }
  }
}
  2551. void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
  2552. const std::vector<tensor::TensorPtr> &inputs_const) {
  2553. if (!ps::PSContext::instance()->is_worker()) {
  2554. return;
  2555. }
  2556. std::vector<tensor::TensorPtr> inputs(inputs_const);
  2557. MS_EXCEPTION_IF_NULL(kernel_graph);
  2558. auto input_nodes = kernel_graph->inputs();
  2559. auto ms_context = MsContext::GetInstance();
  2560. MS_EXCEPTION_IF_NULL(ms_context);
  2561. for (size_t i = 0; i < inputs.size(); ++i) {
  2562. auto tensor = inputs[i];
  2563. MS_EXCEPTION_IF_NULL(tensor);
  2564. auto input_node = input_nodes[i];
  2565. MS_EXCEPTION_IF_NULL(input_node);
  2566. if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
  2567. ps::Worker::GetInstance().InitPSParamAndOptim(input_node, tensor);
  2568. }
  2569. }
  2570. }
  2571. #endif
  2572. } // namespace session
  2573. void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
  2574. const std::vector<CNodePtr> &execution_order) {
  2575. std::string file_path = target_dir + "/execution_order/" + file_name;
  2576. auto realpath = Common::CreatePrefixPath(file_path);
  2577. if (!realpath.has_value()) {
  2578. MS_LOG(ERROR) << "Failed to get real path: [" << file_path << "] in dump graph execution order.";
  2579. return;
  2580. }
  2581. file_path = realpath.value();
  2582. ChangeFileMode(file_path, S_IWUSR);
  2583. // write to csv file
  2584. std::ofstream ofs(file_path);
  2585. if (!ofs.is_open()) {
  2586. MS_LOG(ERROR) << "Failed to open file [" << file_path
  2587. << "] in dump graph execution order, please check the file access permission and whether disk space "
  2588. "is available.";
  2589. return;
  2590. }
  2591. ofs << "NodeExecutionOrder-FullNameWithScope\n";
  2592. for (const CNodePtr &node : execution_order) {
  2593. ofs << node->fullname_with_scope() << "\n";
  2594. }
  2595. ofs.close();
  2596. // set file mode to read only by user
  2597. ChangeFileMode(file_path, S_IRUSR);
  2598. }
  2599. uint32_t GetRankId() {
  2600. uint32_t rank_id = 0;
  2601. auto ms_context = MsContext::GetInstance();
  2602. MS_EXCEPTION_IF_NULL(ms_context);
  2603. std::string world_group;
  2604. std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  2605. if (backend == kAscendDevice) {
  2606. world_group = kHcclWorldGroup;
  2607. } else if (backend == kGPUDevice) {
  2608. world_group = kNcclWorldGroup;
  2609. } else {
  2610. MS_LOG(ERROR) << "Invalid backend: " << backend;
  2611. return rank_id;
  2612. }
  2613. if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
  2614. MS_LOG(INFO) << "Failed to get rank id.";
  2615. }
  2616. return rank_id;
  2617. }
  2618. } // namespace mindspore