You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_parser.cc 180 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
7377837793780378137823783378437853786378737883789379037913792379337943795379637973798379938003801380238033804380538063807380838093810381138123813381438153816381738183819382038213822382338243825382638273828382938303831383238333834383538363837383838393840384138423843384438453846384738483849385038513852385338543855385638573858385938603861386238633864386538663867386838693870387138723873387438753876387738783879388038813882388338843885388638873888388938903891389238933894389538963897389838993900390139023903390439053906390739083909391039113912391339143915391639173918391939203921392239233924392539263927392839293930393139323933393439353936393739383939394039413942394339443945394639473948394939503951395239533954395539563957395839593960396139623963396439653966396739683969397039713972397339743975397639773978397939803981398239833984398539863987398839893990399139923993399439953996399739983999400040014002400340044005
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parser/tensorflow/tensorflow_parser.h"
  17. #include <algorithm>
  18. #include <iostream>
  19. #include "ge/ge_api_types.h"
  20. #include "parser/common/convert/pb2json.h"
  21. #include "parser/common/acl_graph_parser_util.h"
  22. #include "common/util/error_manager/error_manager.h"
  23. #include "external/graph/operator_factory.h"
  24. #include "external/parser/tensorflow_parser.h"
  25. #include "external/register/scope/scope_fusion_pass_register.h"
  26. #include "framework/common/debug/ge_log.h"
  27. #include "framework/omg/parser/parser_api.h"
  28. #include "framework/omg/parser/parser_inner_ctx.h"
  29. #include "graph/debug/ge_attr_define.h"
  30. #include "graph/utils/graph_utils.h"
  31. #include "graph/utils/node_utils.h"
  32. #include "graph/utils/type_utils.h"
  33. #include "iterator_fusion_pass.h"
  34. #include "omg/parser/op_parser.h"
  35. #include "omg/parser/parser_factory.h"
  36. #include "parser/common/acl_graph_parser_util.h"
  37. #include "parser/common/model_saver.h"
  38. #include "parser/common/op_map.h"
  39. #include "parser/common/op_parser_factory.h"
  40. #include "parser/common/parser_fp16_t.h"
  41. #include "parser/common/pass_manager.h"
  42. #include "parser/common/prototype_pass_manager.h"
  43. #include "parser/common/thread_pool.h"
  44. #include "parser/common/parser_utils.h"
  45. #include "parser/common/util.h"
  46. #include "parser/tensorflow/tensorflow_custom_parser_adapter.h"
  47. #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h"
  48. #include "parser/tensorflow/tensorflow_fusion_op_parser.h"
  49. #include "parser/tensorflow/tensorflow_op_parser.h"
  50. #include "parser/tensorflow/tensorflow_util.h"
  51. #include "register/op_registry.h"
  52. #include "register/register_utils.h"
  53. #include "register/scope/scope_pass_registry_impl.h"
  54. #include "parser/common/auto_mapping_subgraph_io_index_func.h"
  55. #include "graph/def_types.h"
  56. using ge::OpParserFactory;
  57. using ge::Pb2Json;
  58. using ge::PreChecker;
  59. using ge::TENSORFLOW_ATTR_DATA_FORMAT;
  60. using ge::TENSORFLOW_ATTR_DTYPE;
  61. using ge::TENSORFLOW_ATTR_SHAPE;
  62. using ge::TENSORFLOW_ATTR_T;
  63. using ge::TENSORFLOW_ATTR_TYPE_STRING;
  64. using ge::TENSORFLOW_ATTR_TYPE_TENSOR;
  65. using ge::TENSORFLOW_ATTR_TYPE_TYPE;
  66. using ge::TENSORFLOW_ATTR_VALUE;
  67. using ge::TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG;
  68. using ge::TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG;
  69. using ge::tensorflow_op_map;
  70. using ge::tensorflow_train_op_map;
  71. using ge::TENSORFLOWF_NODE_OP_CONST;
  72. using ge::TENSORFLOWF_NODE_OP_IDENTITY;
  73. using ge::TENSORFLOWF_NODE_OP_MERGE;
  74. using ge::TENSORFLOWF_NODE_OP_PLACEHOLDER;
  75. using ge::TENSORFLOWF_NODE_OP_SWITCH;
  76. using ge::TENSORFLOWF_NODE_OP_TRANSPOSE;
  77. using ge::TENSORFLOWF_TENSOR_NCHW;
  78. using ge::TENSORFLOWF_TENSOR_NHWC;
  79. using ge::TensorFlowFusionCustomParserAdapter;
  80. using ge::TensorFlowFusionOpParser;
  81. using ge::TensorFlowOpParser;
  82. using ge::ThreadPool;
  83. using ge::parser::fp16_t;
  84. using ge::parser::ModelSaver;
  85. namespace ge {
  86. graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {
  87. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  88. GE_CHECK_NOTNULL(model_file);
  89. GetParserContext().type = domi::TENSORFLOW;
  90. std::map<string, string> options;
  91. options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW)));
  92. // load custom plugin so and proto
  93. AclGrphParseUtil acl_graph_parse_util;
  94. if (acl_graph_parse_util.AclParserInitialize(options) != domi::SUCCESS) {
  95. GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
  96. return GRAPH_FAILED;
  97. }
  98. // Create an empty computegraph
  99. ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
  100. if (compute_graph == nullptr) {
  101. REPORT_CALL_ERROR("E19999", "New ComputeGraph failed");
  102. GELOGE(FAILED, "Create ComputeGraph fail.");
  103. return FAILED;
  104. }
  105. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  106. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW);
  107. if (model_parser == nullptr) {
  108. REPORT_CALL_ERROR("E19999", "No Model Parser for tensorflow, check invalid");
  109. GELOGE(GRAPH_FAILED, "No Model Parser for tensorflow, check invalid");
  110. return FAILED;
  111. }
  112. // parse tensorflow model_file to GE graph
  113. ge::graphStatus ret = model_parser->Parse(model_file, graph);
  114. if (ret != ge::SUCCESS) {
  115. GELOGE(ret, "Parser graph %s failed.", ParserUtils::GetGraphName(graph).c_str());
  116. return ge::FAILED;
  117. }
  118. std::map<AscendString, AscendString> parser_params;
  119. if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
  120. GELOGE(ret, "Set graph %s default output node failed.", ParserUtils::GetGraphName(graph).c_str());
  121. return ge::FAILED;
  122. }
  123. GELOGI("Parser graph %s success.", ParserUtils::GetGraphName(graph).c_str());
  124. return ge::SUCCESS;
  125. }
  126. graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<AscendString, AscendString> &parser_params,
  127. ge::Graph &graph) {
  128. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  129. GE_CHECK_NOTNULL(model_file);
  130. GetParserContext().type = domi::TENSORFLOW;
  131. std::map<string, string> options;
  132. options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW)));
  133. // load custom plugin so and proto
  134. AclGrphParseUtil acl_graph_parse_util;
  135. domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
  136. if (status != domi::SUCCESS) {
  137. GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
  138. return GRAPH_FAILED;
  139. }
  140. string output_name;
  141. if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
  142. GELOGE(ge::FAILED, "Parser params before graph failed.");
  143. return ge::FAILED;
  144. }
  145. // Create an empty computegraph
  146. string graph_name = output_name.empty() ? "tmpGraph" : output_name;
  147. ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name);
  148. if (compute_graph == nullptr) {
  149. REPORT_CALL_ERROR("E19999", "New ComputeGraph failed");
  150. GELOGE(FAILED, "Create ComputeGraph fail.");
  151. return FAILED;
  152. }
  153. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  154. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW);
  155. if (model_parser == nullptr) {
  156. REPORT_CALL_ERROR("E19999", "No Model Parser for tensorflow, check invalid");
  157. GELOGE(GRAPH_FAILED, "No Model Parser for tensorflow, check invalid");
  158. return FAILED;
  159. }
  160. // parse tensorflow model_file to GE graph
  161. ge::graphStatus ret = model_parser->Parse(model_file, graph);
  162. if (ret != ge::SUCCESS) {
  163. GELOGE(ret, "Parser graph %s failed.", ParserUtils::GetGraphName(graph).c_str());
  164. return ge::FAILED;
  165. }
  166. if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
  167. GELOGE(ge::FAILED, "Parser params after graph failed.");
  168. return ge::FAILED;
  169. }
  170. if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
  171. GELOGE(ge::FAILED, "Set graph %s default output node failed.", ParserUtils::GetGraphName(graph).c_str());
  172. return ge::FAILED;
  173. }
  174. GELOGI("AclgrphParse graph %s success.", ParserUtils::GetGraphName(graph).c_str());
  175. return ge::SUCCESS;
  176. }
  177. } // namespace ge
  178. namespace ge {
  179. namespace {
  180. const int kTransposeInputIdx = 0;
  181. const uint32_t kThreadNum = 16;
  182. const size_t kInputNumUint = 2;
  183. const int kInputNumInt = 2;
  184. const int32_t kControlSlot = -1;
  185. const size_t kSoftmaxMultiple = 2;
  186. const set<string> kTfBlackFields = {"tensor_content"};
  187. const std::vector<std::string> kSkipCheckoutInputSizeNodes = {ge::parser::DATA, ge::parser::VARIABLE,
  188. ge::parser::FRAMEWORKOP, ge::parser::LAYERNORM};
  189. const std::vector<std::string> kMakeOperatorNotByIr = {ge::parser::ARG, ge::parser::VARIABLE, ge::parser::VARHANDLEOP,
  190. ge::parser::FRAMEWORKOP, ge::parser::DATA};
  191. const char *const kDpop = "DPOP";
  192. const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt";
  193. const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node";
  194. const char *const kExternalModel = "_external_model";
  195. struct ParseArg {
  196. const google::protobuf::Message *proto;
  197. std::string function_name;
  198. ge::NodePtr parent_node;
  199. std::string subgraph_name;
  200. ge::ComputeGraphPtr graph;
  201. };
  202. Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque<ParseArg> &args) {
  203. GELOGI("Gen subgraph parse tasks start");
  204. for (auto &node : parent_graph->GetDirectNode()) {
  205. auto op_desc = node->GetOpDesc();
  206. GE_CHECK_NOTNULL(op_desc);
  207. for (const auto &subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) {
  208. auto i = subgraph_name_to_index.second;
  209. auto subgraph_iname = op_desc->GetSubgraphInstanceName(i);
  210. if (subgraph_iname.empty()) {
  211. GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str());
  212. continue;
  213. }
  214. // A function may be referenced multiple times in TF, change the graph name to ensure it is unique in GE
  215. auto unique_name = node->GetName() + std::to_string(i) + subgraph_iname;
  216. auto subgraph = ge::parser::MakeShared<ge::ComputeGraph>(unique_name);
  217. if (subgraph == nullptr) {
  218. REPORT_CALL_ERROR("E19999", "New ComputeGraph failed when create subgraph:%s", subgraph_iname.c_str());
  219. GELOGE(OUT_OF_MEMORY, "Failed to alloc subgraph %s", subgraph_iname.c_str());
  220. return OUT_OF_MEMORY;
  221. }
  222. auto ret = ge::NodeUtils::SetSubgraph(*node, i, subgraph);
  223. if (ret != SUCCESS) {
  224. REPORT_CALL_ERROR("E19999", "Set subgraph:%s to node:%s(%s) failed, index:%u", subgraph_iname.c_str(),
  225. node->GetName().c_str(), node->GetType().c_str(), i);
  226. GELOGE(ret, "Set subgraph %s to node %s failed, index %u", subgraph_iname.c_str(), node->GetName().c_str(), i);
  227. return ret;
  228. }
  229. GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s",
  230. node->GetName().c_str(), i, subgraph_iname.c_str());
  231. args.push_back({nullptr, subgraph_iname, node, subgraph_name_to_index.first, subgraph});
  232. }
  233. }
  234. GELOGI("Gen subgraph parse tasks end");
  235. return SUCCESS;
  236. }
  237. Status PostOpProcessForSubgraph(const ParseArg &arg) {
  238. if (arg.parent_node == nullptr) {
  239. return SUCCESS;
  240. }
  241. std::string op_type = arg.parent_node->GetType();
  242. std::string op_name = arg.parent_node->GetName();
  243. domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
  244. auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type);
  245. if (post_func == nullptr) {
  246. GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str());
  247. if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS ||
  248. parse_func_v2 == nullptr) {
  249. GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str());
  250. return SUCCESS;
  251. }
  252. }
  253. GELOGD("Post process for subgraph %s node %s type %s subgraph name %s", arg.function_name.c_str(),
  254. arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), arg.subgraph_name.c_str());
  255. // refresh node_name in subgraph
  256. for (const ge::NodePtr &node : arg.graph->GetDirectNode()) {
  257. if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) {
  258. continue;
  259. }
  260. node->GetOpDesc()->SetName(node->GetOwnerComputeGraph()->GetName() + "/" + node->GetName());
  261. }
  262. auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(arg.graph);
  263. Status ret = FAILED;
  264. if (post_func != nullptr) {
  265. ret = post_func(arg.subgraph_name, graph);
  266. } else if (parse_func_v2 != nullptr) {
  267. ret = parse_func_v2(arg.subgraph_name.c_str(), graph);
  268. }
  269. if (ret != SUCCESS) {
  270. REPORT_CALL_ERROR("E19999", "Call ParseSubgraphPostFunc:%s failed, subgraph:%s, node:%s(%s), ret:0x%X",
  271. arg.function_name.c_str(), arg.subgraph_name.c_str(), arg.parent_node->GetName().c_str(),
  272. arg.parent_node->GetType().c_str(), ret);
  273. GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s subgraph name %s", arg.function_name.c_str(),
  274. arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), arg.subgraph_name.c_str());
  275. return FAILED;
  276. }
  277. return SUCCESS;
  278. }
  279. Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, const ComputeGraphPtr &root_graph) {
  280. // Inner function, input params have been checked by caller
  281. Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(
  282. graph,
  283. [](int in, int &out) -> Status {
  284. out = in;
  285. return SUCCESS;
  286. },
  287. [](int in, int &out) -> Status {
  288. out = in;
  289. return SUCCESS;
  290. });
  291. if (status != SUCCESS) {
  292. GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]node:%s, sub graph name:%s.", node->GetName().c_str(),
  293. ParserUtils::GetGraphName(graph).c_str());
  294. REPORT_CALL_ERROR("E19999", "Failed to map sub graph input and output, node:%s, sub graph name:%s.",
  295. node->GetName().c_str(), ParserUtils::GetGraphName(graph).c_str());
  296. return INTERNAL_ERROR;
  297. }
  298. ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
  299. GE_CHECK_NOTNULL(compute_graph);
  300. // Inner function, GetOpDesc has been checked by caller
  301. (void)node->GetOpDesc()->AddSubgraphName("f");
  302. auto ret = NodeUtils::SetSubgraph(*node, 0, compute_graph);
  303. if (ret != GRAPH_SUCCESS) {
  304. GELOGE(INTERNAL_ERROR, "[Set][Subgraph]Node:%s, sub graph name:%s.", node->GetName().c_str(),
  305. compute_graph->GetName().c_str());
  306. REPORT_CALL_ERROR("E19999", "Failed to set sub graph, node: %s, sub graph name: %s.", node->GetName().c_str(),
  307. compute_graph->GetName().c_str());
  308. return INTERNAL_ERROR;
  309. }
  310. for (const auto &sub_graph : compute_graph->GetAllSubgraphs()) {
  311. ret = root_graph->AddSubgraph(sub_graph);
  312. if (ret != GRAPH_SUCCESS) {
  313. GELOGE(INTERNAL_ERROR, "[Add][Subgraph]Node:%s, sub graph name:%s, sub sub graph name:%s.",
  314. node->GetName().c_str(), compute_graph->GetName().c_str(), sub_graph->GetName().c_str());
  315. REPORT_CALL_ERROR("E19999", "Failed to add sub graph to root graph, node:%s, sub graph name:%s.",
  316. node->GetName().c_str(), sub_graph->GetName().c_str());
  317. return INTERNAL_ERROR;
  318. }
  319. compute_graph->RemoveSubgraph(sub_graph);
  320. GELOGD("Add subgraph[%s] to root graph[%s].", sub_graph->GetName().c_str(), root_graph->GetName().c_str());
  321. }
  322. return SUCCESS;
  323. }
  324. } // namespace
  325. /**
  326. * @ingroup domi_omg
  327. * @brief Trans common decorate function to PartitionedCall.
  328. * @param [in] node_def: Node of common function.
  329. * @param [out] op: result of PartitionedCall OpDesc.
  330. * @return 0: SUCCESS / Others: FAILED
  331. */
  332. Status TensorFlowModelParser::DefunToPartitionedCall(const domi::tensorflow::NodeDef *node_def,
  333. ge::OpDescPtr &op) const {
  334. const string op_name = node_def->name();
  335. domi::tensorflow::AttrValue attr_call_inference;
  336. if (!ge::TensorFlowUtil::FindAttrValue(node_def, "_disable_call_shape_inference", attr_call_inference)) {
  337. ErrorManager::GetInstance().ATCReportErrMessage(
  338. "E19014", {"opname", "value", "reason"},
  339. {node_def->name(), "attr [_disable_call_shape_inference]",
  340. "may has no ir definition, if it is not a common decorate function operator"});
  341. GELOGE(FAILED,
  342. "Op %s has no ir definition, or has no attr [_disable_call_shape_inference] "
  343. "if it is a common decorate function operator.",
  344. op_name.c_str());
  345. return FAILED;
  346. }
  347. op = ge::parser::MakeShared<ge::OpDesc>(op_name, ge::parser::PARTITIONEDCALL);
  348. GE_CHECK_NOTNULL(op);
  349. size_t input_tensor_num = 0;
  350. size_t output_tensor_num = 0;
  351. GetInputOutputTensorNum(op, input_tensor_num, output_tensor_num);
  352. for (size_t i = 0; i < input_tensor_num; ++i) {
  353. ge::GeTensorDesc input_tensor;
  354. if (op->AddInputDesc(input_tensor) != ge::GRAPH_SUCCESS) {
  355. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", op->GetName().c_str(), op->GetType().c_str());
  356. GELOGE(FAILED, "op [%s] type[%s] add input(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i);
  357. return FAILED;
  358. }
  359. }
  360. for (size_t i = 0; i < output_tensor_num; ++i) {
  361. ge::GeTensorDesc output_tensor;
  362. if (op->AddOutputDesc(output_tensor) != ge::GRAPH_SUCCESS) {
  363. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", op->GetName().c_str(), op->GetType().c_str());
  364. GELOGE(FAILED, "op [%s] type[%s] add output(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i);
  365. return FAILED;
  366. }
  367. }
  368. GELOGI("After AddTensorDescToOpDesc op[%s]: type[%s] have input size: %zu, output size: %zu, disable inference: %d",
  369. op_name.c_str(), op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize(), attr_call_inference.b());
  370. (void)op->AddSubgraphName("f");
  371. (void)op->SetSubgraphInstanceName(0, op_name);
  372. return SUCCESS;
  373. }
  374. Status TensorFlowModelParser::TransNodeToOpDesc(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op,
  375. const string &op_type) {
  376. GE_CHECK_NOTNULL(node_def);
  377. string node_name = node_def->name();
  378. ge::Operator op_factory = ge::OperatorFactory::CreateOperator(node_name.c_str(), op_type.c_str());
  379. if (ParserUtils::GetOperatorName(op_factory) != node_name || op_type == ge::parser::DATA) {
  380. if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) {
  381. op = ge::parser::MakeShared<ge::OpDesc>(node_name, op_type);
  382. GE_CHECK_NOTNULL(op);
  383. } else if (node_name == op_type) {
  384. // Trans @tensorflow.python.framework.Defun(...) to PartitionedCall.
  385. GE_RETURN_IF_ERROR(DefunToPartitionedCall(node_def, op));
  386. GE_CHECK_NOTNULL(op);
  387. } else {
  388. ErrorManager::GetInstance().ATCReportErrMessage("E10501", {"opname", "optype"}, {node_name, op_type});
  389. GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str());
  390. return FAILED;
  391. }
  392. } else {
  393. op = ge::OpDescUtils::GetOpDescFromOperator(op_factory);
  394. GE_CHECK_NOTNULL(op);
  395. GELOGI("After GetOpDescFromOperator op[%s]: type[%s] has input size: %zu, output size: %zu", op->GetName().c_str(),
  396. op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
  397. GE_RETURN_IF_ERROR(AddTensorDescToOpDesc(op, node_def));
  398. GELOGI("After AddTensorDescToOpDesc op[%s]: type[%s] has input size: %zu, output size: %zu", op->GetName().c_str(),
  399. op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
  400. }
  401. op_factory.BreakConnect();
  402. return SUCCESS;
  403. }
  404. Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op,
  405. const shared_ptr<OpParser> &op_parser) {
  406. GE_CHECK_NOTNULL(node_def);
  407. GE_CHECK_NOTNULL(op);
  408. GE_CHECK_NOTNULL(op_parser);
  409. string node_name = node_def->name();
  410. string node_op = node_def->op();
  411. Status status = FAILED;
  412. domi::ParseParamByOpFunc parse_param_by_op_fn = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(node_op);
  413. if (parse_param_by_op_fn == nullptr) {
  414. shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  415. GE_CHECK_NOTNULL(tensorflow_op_parser);
  416. status = tensorflow_op_parser->ParseParams(node_def, op);
  417. if (status != SUCCESS) {
  418. GELOGE(status, "Parse params for node[%s] failed", node_name.c_str());
  419. return status;
  420. }
  421. } else {
  422. ge::Operator op_src(node_def->name().c_str(), node_def->op().c_str());
  423. status = domi::OperatorAutoMapping(node_def, op_src);
  424. if (status != SUCCESS) {
  425. REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", node_def->name().c_str(),
  426. node_def->op().c_str());
  427. GELOGE(status, "Node[%s] auto mapping failed.", node_name.c_str());
  428. return status;
  429. }
  430. std::shared_ptr<ge::TensorFlowCustomParserAdapter> tf_custom_op_parser =
  431. std::dynamic_pointer_cast<ge::TensorFlowCustomParserAdapter>(op_parser);
  432. GE_CHECK_NOTNULL(tf_custom_op_parser);
  433. status = tf_custom_op_parser->ParseParams(op_src, op);
  434. if (status != SUCCESS) {
  435. GELOGE(status, "Parse params for node[%s] failed", ParserUtils::GetOperatorName(op_src).c_str());
  436. return status;
  437. }
  438. }
  439. domi::tensorflow::AttrValue attr;
  440. if (ge::TensorFlowUtil::FindAttrValue(node_def, ATTR_NAME_QOS_SERVICE_LABEL, attr)) {
  441. (void)ge::AttrUtils::SetInt(*op, ATTR_NAME_QOS_SERVICE_LABEL, static_cast<int64_t>(attr.i()));
  442. }
  443. return SUCCESS;
  444. }
  445. Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def, ge::ComputeGraphPtr &graph,
  446. shared_ptr<ge::ScopeGraph> &scope_graph) {
  447. GE_CHECK_NOTNULL(node_def);
  448. GE_CHECK_NOTNULL(graph);
  449. GE_CHECK_NOTNULL(scope_graph);
  450. domi::tensorflow::AttrValue attr_value;
  451. if (ge::TensorFlowUtil::FindAttrValue(node_def, kAttrNameIsScopeInnerNode, attr_value) && attr_value.b()) {
  452. std::mutex graph_mutex;
  453. return AddScopeInnerNode(this, graph, &graph_mutex, node_def);
  454. }
  455. // node is released in destructor
  456. string node_name = node_def->name();
  457. string node_op = node_def->op();
  458. std::map<std::string, std::string>::const_iterator type_it = tensorflow_op_map.find(node_op);
  459. if (type_it == tensorflow_op_map.end()) {
  460. GELOGI("Can not find,maybe this node has no plugin node_name is %s, node_op is %s ", node_name.c_str(),
  461. node_op.c_str());
  462. ge::OpDescPtr op_desc;
  463. GE_RETURN_IF_ERROR(TransNodeToOpDesc(node_def, op_desc, node_op));
  464. ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
  465. GE_CHK_STATUS(domi::OperatorAutoMapping(node_def, op));
  466. op.BreakConnect();
  467. ge::NodePtr node = nullptr;
  468. node = graph->AddNode(op_desc);
  469. if (node == nullptr) {
  470. DeleteFuisonNodeDef();
  471. GELOGE(FAILED, "add node failed.");
  472. return INTERNAL_ERROR;
  473. }
  474. node_map_[node_name] = node;
  475. return SUCCESS;
  476. }
  477. string op_type = type_it->second;
  478. // The type value is obtained from the definition map set of DaVinci.
  479. ge::OpDescPtr op;
  480. GE_RETURN_IF_ERROR(TransNodeToOpDesc(node_def, op, op_type));
  481. bool needFusion = IsFusionOp(scope_graph, node_def);
  482. // The number of inputs and outputs of each operator can be determined after the new IR design model is resolved.
  483. // Add tensordesc to the opdesc object of the operator
  484. // Process change of tensordesc initialization of opdesc,
  485. // Previous process: Tensordesc is constructed according to graph structure in builder stage
  486. // Current process: Tensordesc is determined before the opdesc of the operator is added to the graph
  487. Status status = FAILED;
  488. // create OpParser
  489. shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  490. GE_CHECK_NOTNULL(factory);
  491. if (!needFusion) {
  492. shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);
  493. // parse op param
  494. status = ParseOpParams(node_def, op, op_parser);
  495. if (status != SUCCESS) {
  496. GELOGE(status, "Parse params for node[%s] failed", node_name.c_str());
  497. return status;
  498. }
  499. }
  500. GELOGI("After op parser op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(),
  501. op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
  502. // checkout op input number with IR
  503. GE_RETURN_IF_ERROR(CheckoutInputNum(op, node_def));
  504. ge::NodePtr node = graph->AddNode(op);
  505. if (node == nullptr) {
  506. DeleteFuisonNodeDef();
  507. GELOGE(FAILED, "add node failed.");
  508. return INTERNAL_ERROR;
  509. }
  510. node_map_[node_name] = node;
  511. if (needFusion) {
  512. shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type);
  513. GE_CHECK_NOTNULL(fusion_op_parser);
  514. // Find all children of the fusion operator
  515. std::map<string, vector<const NodeDef *>>::const_iterator iter = fusion_op_nodedef_map_.find(node_def->name());
  516. if (iter == fusion_op_nodedef_map_.end()) {
  517. REPORT_INNER_ERROR("E19999", "FusionOp node %s has no children node, check invalid", node_name.c_str());
  518. GELOGE(FAILED, "FusionOp node %s has no children node!", node_name.c_str());
  519. return INTERNAL_ERROR;
  520. }
  521. vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second;
  522. // parse fusion node param
  523. status = FusionNodeParseParams(fusion_op_parser, node_def, node);
  524. if (status != SUCCESS) {
  525. GELOGE(status, "Parse params for fusion node[%s] failed", node_name.c_str());
  526. return status;
  527. }
  528. // record original op names
  529. std::vector<std::string> namesTmp;
  530. for (auto &node_def_iter : node_def_v) {
  531. GE_CHECK_NOTNULL(node_def_iter);
  532. std::string nodeName = node_def_iter->name();
  533. namesTmp.push_back(nodeName);
  534. }
  535. ge::GraphUtils::RecordOriginalNames(namesTmp, node);
  536. status = RecordFusionResult(scope_graph, node_def, op);
  537. if (status != SUCCESS) {
  538. GELOGE(INTERNAL_ERROR, "Record fusion result for fusion op: %s failed", op->GetName().c_str());
  539. DeleteFuisonNodeDef();
  540. return status;
  541. }
  542. }
  543. return SUCCESS;
  544. }
  545. void TensorFlowModelParser::GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num,
  546. size_t &output_tensor_num) const {
  547. // The caller guarantees that the pointer is not null
  548. auto iter = op_node_context_map_.find(op_desc->GetName());
  549. if (iter == op_node_context_map_.end()) {
  550. return;
  551. }
  552. const OpNodeContext &op_context = iter->second;
  553. const std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &dest_input_map = op_context.input_map;
  554. // input number
  555. input_tensor_num = 0;
  556. for (auto &input_vec : dest_input_map) {
  557. for (auto &input_v : input_vec.second) {
  558. if (input_v.second != kControlSlot) {
  559. input_tensor_num++;
  560. }
  561. }
  562. }
  563. // output number
  564. const std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &src_output_map = op_context.output_map;
  565. int32_t max_anchor_index = 0;
  566. for (auto &src_output_iter : src_output_map) {
  567. for (auto &index_output_iter : src_output_iter.second) {
  568. if (index_output_iter.first > max_anchor_index) {
  569. max_anchor_index = index_output_iter.first;
  570. }
  571. }
  572. }
  573. output_tensor_num = max_anchor_index + 1;
  574. }
  575. Status TensorFlowModelParser::CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node) const {
  576. GE_CHECK_NOTNULL(node);
  577. GE_CHECK_NOTNULL(op_desc);
  578. if (std::find(kSkipCheckoutInputSizeNodes.begin(), kSkipCheckoutInputSizeNodes.end(), op_desc->GetType()) !=
  579. kSkipCheckoutInputSizeNodes.end()) {
  580. return SUCCESS;
  581. }
  582. // get input and output tensor number
  583. size_t input_tensor_num = 0;
  584. size_t output_tensor_num = 0;
  585. GetInputOutputTensorNum(op_desc, input_tensor_num, output_tensor_num);
  586. // get input and output tensor number from op desc
  587. size_t factory_input_size = op_desc->GetInputsSize();
  588. if (input_tensor_num != factory_input_size) {
  589. ErrorManager::GetInstance().ATCReportErrMessage(
  590. "E19014", {"opname", "value", "reason"},
  591. {op_desc->GetName(), "input number of tensorflow[" + std::to_string(input_tensor_num) + "]",
  592. "should be equal to factory size[" + std::to_string(factory_input_size) + "]"});
  593. GELOGE(FAILED, "op [%s], type[%s], The input number of tensorflow[%zu] should be equal to factory size[%zu]",
  594. op_desc->GetName().c_str(), op_desc->GetType().c_str(), input_tensor_num, factory_input_size);
  595. return FAILED;
  596. }
  597. return SUCCESS;
  598. }
  599. void TensorFlowModelParser::UpdateInputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &input_desc,
  600. const size_t input_tensor_num) {
  601. // The caller guarantees that the pointer is not null
  602. for (size_t i = 0; i < input_tensor_num; ++i) {
  603. if (i < input_desc.size()) {
  604. // i is guaranteed to be valid, no check required.
  605. ge::graphStatus ret = op_desc->UpdateInputDesc(op_desc->GetInputNameByIndex(i), input_desc[i]);
  606. if (ret != ge::GRAPH_SUCCESS) {
  607. // UpdateInputDesc for dynamic intput will be failed, but it will be added in later op parser.
  608. GELOGI("op [%s], type[%s], input(%zu) with name %s is not updated", op_desc->GetName().c_str(),
  609. op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str());
  610. }
  611. } else {
  612. ge::GeTensorDesc input_tensor;
  613. // i is guaranteed to be valid, no check required.
  614. ge::graphStatus ret = op_desc->UpdateInputDesc(op_desc->GetInputNameByIndex(i), input_tensor);
  615. if (ret != ge::GRAPH_SUCCESS) {
  616. // UpdateInputDesc for dynamic intput will be failed, but it will be added in later op parser.
  617. GELOGI("op [%s], type[%s], input(%zu) with name %s is not updated", op_desc->GetName().c_str(),
  618. op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str());
  619. }
  620. }
  621. }
  622. }
  623. void TensorFlowModelParser::UpdateOutputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &output_desc,
  624. size_t output_tensor_num) {
  625. // The caller guarantees that the pointer is not null
  626. for (size_t i = 0; i < output_tensor_num; ++i) {
  627. if (i < output_desc.size()) {
  628. // i is guaranteed to be valid, no check required.
  629. ge::graphStatus ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(i), output_desc[i]);
  630. if (ret != ge::GRAPH_SUCCESS) {
  631. // UpdateOutputDesc for dynamic output will be failed, but it will be added in later op parser.
  632. GELOGI("op [%s], type[%s], output(%zu) with name %s is not updated", op_desc->GetName().c_str(),
  633. op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str());
  634. }
  635. } else {
  636. ge::GeTensorDesc output_tensor;
  637. // i is guaranteed to be valid, no check required.
  638. ge::graphStatus ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(i), output_tensor);
  639. if (ret != ge::GRAPH_SUCCESS) {
  640. // UpdateOutputDesc for dynamic output will be failed, but it will be added in later op parser.
  641. GELOGI("op [%s], type[%s], output(%zu) with name %s is not updated", op_desc->GetName().c_str(),
  642. op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str());
  643. }
  644. }
  645. }
  646. }
  647. Status TensorFlowModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc,
  648. const domi::tensorflow::NodeDef *node) const {
  649. GE_CHECK_NOTNULL(node);
  650. GE_CHECK_NOTNULL(op_desc);
  651. // get input and output attr from tensorflow
  652. const string type = node->op();
  653. domi::tensorflow::AttrValue input_attr_value;
  654. domi::tensorflow::AttrValue output_attr_value;
  655. ParserOperator temp_op;
  656. if (ge::TensorFlowUtil::FindAttrValue(node, ge::parser::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) {
  657. GE_CHK_STATUS_RET(ge::TensorFlowUtil::TransTensorDescriptor(input_attr_value, &temp_op,
  658. TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG, type),
  659. "trans input_attr_value failed, op: %s", node->name().c_str());
  660. } else {
  661. GELOGD("Frameworkop has no input tensor desc, name:%s, type:%s.", node->name().c_str(), type.c_str());
  662. }
  663. if (ge::TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) {
  664. GE_CHK_STATUS_RET(ge::TensorFlowUtil::TransTensorDescriptor(output_attr_value, &temp_op,
  665. TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG, type),
  666. "trans output_attr_value failed, op: %s", node->name().c_str());
  667. } else {
  668. GELOGD("Frameworkop has no output tensor desc, name:%s, type:%s.", node->name().c_str(), type.c_str());
  669. }
  670. auto iter = op_node_context_map_.find(op_desc->GetName());
  671. if (iter == op_node_context_map_.end()) {
  672. return SUCCESS;
  673. }
  674. const std::vector<ge::GeTensorDesc> &input_desc = temp_op.GetInputTensorDesc();
  675. const std::vector<ge::GeTensorDesc> &output_desc = temp_op.GetOutputTensorDesc();
  676. // get input and output tensor number
  677. size_t input_tensor_num = 0;
  678. size_t output_tensor_num = 0;
  679. GetInputOutputTensorNum(op_desc, input_tensor_num, output_tensor_num);
  680. // update input
  681. UpdateInputTensor(op_desc, input_desc, input_tensor_num);
  682. // update output
  683. UpdateOutputTensor(op_desc, output_desc, output_tensor_num);
  684. return SUCCESS;
  685. }
  686. Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) {
  687. GE_CHECK_NOTNULL(graph);
  688. for (auto &src_iter : op_node_context_map_) {
  689. string src_op_name = src_iter.first;
  690. OpNodeContext src_op_node_context = src_iter.second;
  691. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &src_output_map = src_op_node_context.output_map;
  692. // Traverse all output of the op_node
  693. for (auto &src_output_iter : src_output_map) {
  694. string dest_op_name = src_output_iter.first;
  695. auto dest_iter = op_node_context_map_.find(dest_op_name);
  696. if (dest_iter == op_node_context_map_.end()) {
  697. continue;
  698. }
  699. // Find that the output of the source node is equal to the destination node
  700. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &dest_input_map = dest_iter->second.input_map;
  701. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>>::const_iterator
  702. input_iter = dest_input_map.find(src_op_name);
  703. // Find output and input
  704. if (input_iter == dest_input_map.end()) {
  705. continue;
  706. }
  707. auto iter = node_map_.find(src_op_name);
  708. if (iter == node_map_.end()) {
  709. continue;
  710. }
  711. ge::NodePtr src = iter->second;
  712. GE_CHECK_NOTNULL(src);
  713. auto iter1 = node_map_.find(dest_op_name);
  714. if (iter1 == node_map_.end()) {
  715. continue;
  716. }
  717. // Each pair builds an edge
  718. ge::NodePtr dest = iter1->second;
  719. GE_CHECK_NOTNULL(dest);
  720. if (src_output_iter.second.size() != input_iter->second.size()) {
  721. REPORT_INNER_ERROR("E19999", "Input size of op[%s]:%zu is not equal to Output size of op[%s]:%zu.",
  722. src_op_name.c_str(), input_iter->second.size(), dest_op_name.c_str(),
  723. src_output_iter.second.size());
  724. GELOGE(INTERNAL_ERROR, "Input size of op[%s]:%zu is not equal to Output size of op[%s]:%zu.",
  725. src_op_name.c_str(), input_iter->second.size(), dest_op_name.c_str(), src_output_iter.second.size());
  726. return INTERNAL_ERROR;
  727. }
  728. for (auto &outputpair : src_output_iter.second) {
  729. // Get control edge properties
  730. bool control = GetEdgesControlInfo(dest_op_name, outputpair.second);
  731. // Graph create new edge
  732. if (!control) {
  733. GELOGD("Start add edge: from %s:%d to %s:%d.", src->GetName().c_str(), outputpair.first,
  734. dest->GetName().c_str(), outputpair.second);
  735. ge::OutDataAnchorPtr out_archor_ptr = src->GetOutDataAnchor(outputpair.first);
  736. GE_CHECK_NOTNULL(out_archor_ptr);
  737. ge::InDataAnchorPtr in_archor_ptr = dest->GetInDataAnchor(outputpair.second);
  738. GE_CHECK_NOTNULL(in_archor_ptr);
  739. if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) {
  740. REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed",
  741. src->GetName().c_str(), dest->GetName().c_str());
  742. GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str());
  743. return INTERNAL_ERROR;
  744. }
  745. } else {
  746. GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str());
  747. ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor();
  748. GE_CHECK_NOTNULL(in_archor_ptr);
  749. ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor();
  750. GE_CHECK_NOTNULL(out_archor_ptr);
  751. if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) {
  752. REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed",
  753. src->GetName().c_str(), dest->GetName().c_str());
  754. GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str());
  755. return INTERNAL_ERROR;
  756. }
  757. }
  758. }
  759. dest_input_map.erase(input_iter);
  760. }
  761. }
  762. return SUCCESS;
  763. }
  764. Status TensorFlowModelParser::AddFmkNodeDefToMap(const domi::tensorflow::NodeDef *node_def,
  765. vector<string> &op_node_name_list) {
  766. GE_CHECK_NOTNULL(node_def);
  767. const string &node_name = node_def->name();
  768. nodedef_map_[node_name] = node_def;
  769. OpNodeContext op_node_context;
  770. op_node_context_map_[node_name] = op_node_context;
  771. op_node_name_list.push_back(node_name);
  772. return SUCCESS;
  773. }
  774. Status TensorFlowModelParser::CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set<int> &dims,
  775. bool &valid) {
  776. GE_CHECK_NOTNULL(node_def);
  777. domi::tensorflow::AttrValue input_attr_value;
  778. bool is_attr_exist =
  779. ge::TensorFlowUtil::FindAttrValue(node_def, ge::parser::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value);
  780. GE_IF_BOOL_EXEC(!is_attr_exist, return SUCCESS);
  781. GE_CHK_BOOL_EXEC(input_attr_value.has_list(),
  782. REPORT_INNER_ERROR("E19999", "Attr:%s of node_def:%s(%s) is empty, check invalid",
  783. ge::parser::ATTR_NAME_INPUT_TENSOR_DESC.c_str(), node_def->name().c_str(),
  784. node_def->op().c_str());
  785. return PARAM_INVALID, "output attr value vector is empty");
  786. // list contain many TensorDescriptors
  787. domi::tensorflow::AttrValue_ListValue a_list = input_attr_value.list();
  788. for (int32_t i = 0; i < a_list.func_size(); i++) {
  789. ge::GeTensorDesc ge_desc;
  790. int32_t tf_datatype = 0;
  791. GE_CHK_BOOL_RET_STATUS(ge::TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, i, tf_datatype), PARAM_INVALID,
  792. "parse ge_desc failed.");
  793. for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
  794. int64_t temp_dim = ge_desc.GetShape().GetDim(j);
  795. GE_IF_BOOL_EXEC(dims.count(temp_dim) > 0, valid = false);
  796. }
  797. }
  798. return SUCCESS;
  799. }
  800. Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type) {
  801. GE_CHECK_NOTNULL(node_def);
  802. bool valid = true;
  803. string node_name = node_def->name();
  804. std::map<std::string, set<int>> check_dims = {
  805. {ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, {10}},
  806. };
  807. GE_IF_BOOL_EXEC(
  808. op_type == ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS,
  809. GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape");
  810. GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP; GELOGI("Set op %s to frameworkop", node_name.c_str());
  811. framework_ops_[node_name] = node_def;););
  812. GE_IF_BOOL_EXEC(
  813. op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN,
  814. for (const string &input_name
  815. : node_def->input()) {
  816. string tmp_input_name;
  817. GE_RETURN_IF_ERROR(CheckInputNodeName(input_name, &tmp_input_name, nullptr, nullptr));
  818. GELOGD("Add or Mul op %s input name is %s", node_name.c_str(), input_name.c_str());
  819. GE_IF_BOOL_EXEC(framework_ops_.find(tmp_input_name) != framework_ops_.end(),
  820. GELOGI("Set op %s to frameworkop", node_name.c_str());
  821. op_type = ge::parser::FRAMEWORKOP;);
  822. });
  823. return SUCCESS;
  824. }
  825. /*
  826. * @ingroup domi_omg
  827. * @brief Mapping TF's datatype to GE's datatype
  828. * @param [in] type, datatype types of operators in TF networks
  829. * @return ge::DataType
  830. */
  831. ge::DataType TensorFlowModelParser::ConvertToGeDataType(const uint32_t type) {
  832. ErrorManager::GetInstance().GenWorkStreamIdDefault();
  833. ge::DataType data_type = domi::TensorAssign::ConvertTensorflowDataType(type);
  834. return data_type;
  835. }
  836. Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph,
  837. std::mutex *graphMutex, shared_ptr<ge::ScopeGraph> &scope_graph,
  838. const domi::tensorflow::NodeDef *node_def,
  839. error_message::Context error_context) {
  840. ErrorManager::GetInstance().SetErrorContext(error_context);
  841. // The caller guarantees that the pointer is not null
  842. string node_name = node_def->name();
  843. string node_op = node_def->op();
  844. GELOGD("TF op node name = %s, op type= %s", node_name.c_str(), node_op.c_str());
  845. domi::tensorflow::AttrValue attr_value;
  846. if (ge::TensorFlowUtil::FindAttrValue(node_def, kAttrNameIsScopeInnerNode, attr_value) && attr_value.b()) {
  847. return AddScopeInnerNode(parser, graph, graphMutex, node_def);
  848. }
  849. std::map<std::string, std::string>::const_iterator iterator = parser->adaptedOpTypeMap_.find(node_name);
  850. if (iterator == parser->adaptedOpTypeMap_.cend()) {
  851. REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str());
  852. GELOGE(FAILED, "get adapted op type failed, node name = %s", node_name.c_str());
  853. return FAILED;
  854. }
  855. string op_type = iterator->second;
  856. // Log printing for determining operator type
  857. domi::ImplyType implyType = domi::OpRegistry::Instance()->GetImplyType(op_type);
  858. GE_IF_BOOL_EXEC((implyType == domi::ImplyType::TVM) && (op_type != ge::parser::FRAMEWORKOP),
  859. GELOGD("TBE %s parsering", node_op.c_str()););
  860. GE_IF_BOOL_EXEC((implyType == domi::ImplyType::CCE) && (op_type != ge::parser::FRAMEWORKOP),
  861. GELOGD("CCE %s parsering", node_op.c_str()););
  862. GE_IF_BOOL_EXEC((implyType == domi::ImplyType::HCCL) && (op_type != ge::parser::FRAMEWORKOP),
  863. GELOGD("HCCL %s parsering", node_op.c_str()););
  864. GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP, GELOGD("FRAMEWORKOP %s parsering", node_op.c_str()););
  865. GELOGD("TF op node name = %s, op type= %s, trans to op type %s", node_name.c_str(), node_op.c_str(), op_type.c_str());
  866. // Construct operator by IR
  867. ge::OpDescPtr op;
  868. ge::Operator op_factory = ge::OperatorFactory::CreateOperator(node_name.c_str(), op_type.c_str());
  869. if (ParserUtils::GetOperatorName(op_factory) != node_name) {
  870. if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) {
  871. op = ge::parser::MakeShared<ge::OpDesc>(node_name, op_type);
  872. GE_CHECK_NOTNULL(op);
  873. } else if (node_name == op_type) {
  874. GE_RETURN_IF_ERROR(parser->DefunToPartitionedCall(node_def, op));
  875. GE_CHECK_NOTNULL(op);
  876. ge::Operator op_tmp = ge::OpDescUtils::CreateOperatorFromOpDesc(op);
  877. GE_CHK_STATUS(domi::OperatorAutoMapping(node_def, op_tmp));
  878. op_tmp.BreakConnect();
  879. ge::NodePtr node;
  880. {
  881. std::lock_guard<std::mutex> lock(*graphMutex);
  882. node = graph->AddNode(op);
  883. }
  884. GE_CHECK_NOTNULL(node);
  885. {
  886. std::lock_guard<std::mutex> lock(parser->nodeMapMutex_);
  887. parser->node_map_[node_name] = node;
  888. }
  889. return SUCCESS;
  890. } else {
  891. REPORT_INPUT_ERROR("E10501", std::vector<std::string>({"opname", "optype"}),
  892. std::vector<std::string>({node_name, op_type}));
  893. GELOGE(INTERNAL_ERROR, "op[%s] type[%s] have no ir factory.]", node_name.c_str(), op_type.c_str());
  894. return FAILED;
  895. }
  896. } else {
  897. op = ge::OpDescUtils::GetOpDescFromOperator(op_factory);
  898. GE_CHECK_NOTNULL(op);
  899. GELOGD("After GetOpDescFromOperator op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(),
  900. op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
  901. GE_RETURN_IF_ERROR(parser->AddTensorDescToOpDesc(op, node_def));
  902. GELOGD("After AddTensorDescToOpDesc op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(),
  903. op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
  904. }
  905. GELOGD("TF op node name = %s, outpusize= %zu", node_name.c_str(), op->GetAllOutputsDesc().size());
  906. op_factory.BreakConnect();
  907. // create OpParser
  908. shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  909. GE_CHECK_NOTNULL(factory);
  910. bool needFusion = parser->IsFusionOp(scope_graph, node_def);
  911. GELOGD("TF op node name = %s, op type= %s is fusion op(NO: 0; YES: 1)= %d", node_name.c_str(), node_op.c_str(),
  912. needFusion);
  913. Status status = FAILED;
  914. if (!needFusion) {
  915. shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);
  916. status = parser->ParseOpParams(node_def, op, op_parser);
  917. if (status != SUCCESS) {
  918. GELOGE(status, "Parse params for node[%s] failed", node_name.c_str());
  919. return status;
  920. }
  921. }
  922. GELOGD("After op parser op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(),
  923. op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
  924. // checkout op input number with IR
  925. GE_RETURN_IF_ERROR(parser->CheckoutInputNum(op, node_def));
  926. if (needFusion) {
  927. status = RecordFusionResult(scope_graph, node_def, op);
  928. if (status != SUCCESS) {
  929. GELOGE(INTERNAL_ERROR, "Record fusion result for fusion op: %s failed", op->GetName().c_str());
  930. return status;
  931. }
  932. }
  933. ge::NodePtr node;
  934. {
  935. std::lock_guard<std::mutex> lock(*graphMutex);
  936. node = graph->AddNode(op);
  937. }
  938. if (node == nullptr) {
  939. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(),
  940. op->GetType().c_str(), graph->GetName().c_str());
  941. GELOGE(FAILED, "add node failed.");
  942. return INTERNAL_ERROR;
  943. }
  944. if (needFusion) {
  945. shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type);
  946. status = parser->FusionNodeParseParams(fusion_op_parser, node_def, node);
  947. GE_CHK_STATUS_EXEC(status, return status, "Parse Params for node %s failed", node_name.c_str());
  948. }
  949. {
  950. std::lock_guard<std::mutex> lock(parser->nodeMapMutex_);
  951. parser->node_map_[node_name] = node;
  952. }
  953. return SUCCESS;
  954. }
  955. Status TensorFlowModelParser::AdaptOpType(const domi::tensorflow::NodeDef *node_def, bool isDatasetInit) {
  956. // The caller guarantees that the pointer is not null
  957. string node_name = node_def->name();
  958. string node_op = node_def->op();
  959. string op_type;
  960. if (tensorflow_train_op_map.find(node_op) != tensorflow_train_op_map.end()) {
  961. op_type = tensorflow_train_op_map.at(node_op);
  962. GE_CHK_STATUS_RET(CheckOpType(node_def, op_type), "Failed to check op type");
  963. } else {
  964. op_type = ge::parser::FRAMEWORKOP;
  965. domi::tensorflow::AttrValue attr_call_inference;
  966. if ((node_name == node_op) &&
  967. ge::TensorFlowUtil::FindAttrValue(node_def, "_disable_call_shape_inference", attr_call_inference)) {
  968. op_type = node_op;
  969. }
  970. }
  971. GE_IF_BOOL_EXEC(isDatasetInit, op_type = ge::parser::FRAMEWORKOP);
  972. adaptedOpTypeMap_[node_name] = op_type;
  973. return SUCCESS;
  974. }
  975. Status TensorFlowModelParser::AddFmkNode(ge::ComputeGraphPtr &graph, shared_ptr<ge::ScopeGraph> &scope_graph,
  976. vector<string> &op_node_name_list, bool is_dataset_init) {
  977. GE_CHECK_NOTNULL(graph);
  978. GE_CHECK_NOTNULL(scope_graph);
  979. GE_RETURN_IF_ERROR(AddFusionNodeDef(scope_graph, op_node_name_list));
  980. size_t op_node_list_size = op_node_name_list.size();
  981. for (size_t i = 0; i < op_node_list_size; ++i) {
  982. const string op_node_name = op_node_name_list[i];
  983. const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name];
  984. GE_CHECK_NOTNULL(node_def);
  985. GE_RETURN_IF_ERROR(AdaptOpType(node_def, is_dataset_init));
  986. }
  987. GELOGD("Add fusion nodedef and Adapt op type success");
  988. // Multithreading parallel parsing nodedef
  989. ThreadPool executor(kThreadNum);
  990. std::mutex graphMutex;
  991. std::vector<std::future<Status>> vectorFuture(op_node_list_size);
  992. ge::ComputeGraphPtr graph_tmp = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
  993. GE_CHECK_NOTNULL(graph_tmp);
  994. for (size_t j = 0; j < op_node_list_size; j++) {
  995. const string op_node_name = op_node_name_list[j];
  996. const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name];
  997. GE_CHECK_NOTNULL(node_def);
  998. std::future<Status> f =
  999. executor.commit(TensorFlowModelParser::ParseNodeDef, this, graph_tmp, &graphMutex, scope_graph, node_def,
  1000. ErrorManager::GetInstance().GetErrorManagerContext());
  1001. if (!f.valid()) {
  1002. GELOGE(FAILED, "Future is invalid");
  1003. return FAILED;
  1004. }
  1005. vectorFuture[j] = std::move(f);
  1006. }
  1007. GELOGD("Parse nodedef success");
  1008. // Wait for the return value of each thread. If the thread does not finish processing, it will block here
  1009. bool ret_flag = true;
  1010. size_t futureSize = vectorFuture.size();
  1011. for (size_t i = 0; i < futureSize; ++i) {
  1012. Status retStatus = vectorFuture[i].get();
  1013. if (retStatus != SUCCESS) {
  1014. ret_flag = false;
  1015. }
  1016. }
  1017. if (!ret_flag) {
  1018. return FAILED;
  1019. }
  1020. return AddNodeToGraphAndMarkFormat(graph, op_node_name_list);
  1021. }
  1022. Status TensorFlowModelParser::AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &graph,
  1023. const vector<string> &op_node_name_list) {
  1024. // Add ge:: nodeptr to graph in order
  1025. size_t op_node_list_size = op_node_name_list.size();
  1026. for (size_t j = 0; j < op_node_list_size; j++) {
  1027. const string op_node_name = op_node_name_list[j];
  1028. auto iterator = node_map_.find(op_node_name);
  1029. if (iterator == node_map_.end()) {
  1030. REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str());
  1031. GELOGE(FAILED, "add node failed.");
  1032. return INTERNAL_ERROR;
  1033. }
  1034. GE_CHECK_NOTNULL(iterator->second);
  1035. GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed");
  1036. graph->AddNode(iterator->second);
  1037. }
  1038. return SUCCESS;
  1039. }
  1040. Status TensorFlowModelParser::ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *const graph_def,
  1041. shared_ptr<ge::ScopeGraph> &scope_graph) {
  1042. // Identifying scope fusion operators based on scope rules
  1043. GE_CHECK_NOTNULL(graph_def);
  1044. ScopePassManager passmanager;
  1045. PARSER_TIMESTAMP_START(BuildScopeGraph);
  1046. scope_graph = passmanager.BuildScopeGraph(graph_def);
  1047. GE_CHECK_NOTNULL(scope_graph);
  1048. PARSER_TIMESTAMP_END(BuildScopeGraph, "TensorFlowModelParser::BuildScopeGraph");
  1049. PARSER_TIMESTAMP_START(ScopeGraphPass);
  1050. // Validate the non-general scope fusion pass.
  1051. // The parameter is set to the name of the fusion rule.
  1052. // Multiple names can be set and separated by ",".
  1053. std::vector<std::string> enable_pass_names =
  1054. ge::StringUtils::Split(ge::GetParserContext().enable_scope_fusion_passes, ',');
  1055. auto &impl = ge::ScopeFusionPassRegistry::GetInstance().impl_;
  1056. if (impl == nullptr) {
  1057. REPORT_INNER_ERROR("E19999", "ScopeFusionPassRegistry is not properly initialized.");
  1058. GELOGE(ge::MEMALLOC_FAILED, "ScopeFusionPassRegistry is not properly initialized.");
  1059. return ge::MEMALLOC_FAILED;
  1060. }
  1061. for (size_t i = 0; i < enable_pass_names.size(); ++i) {
  1062. if (enable_pass_names[i].empty()) {
  1063. continue;
  1064. }
  1065. if (!impl->SetPassEnableFlag(enable_pass_names[i], true)) {
  1066. GELOGW("Failed to set enable flag of scope fusion pass:%s", enable_pass_names[i].c_str());
  1067. }
  1068. }
  1069. std::vector<std::string> scope_passes_list = impl->GetAllRegisteredPasses();
  1070. Status ret = RunScopeFusionPass(scope_passes_list, passmanager, scope_graph);
  1071. if (ret != SUCCESS) {
  1072. GELOGE(ret, "Run scope fusion failed, ret:%u.", ret);
  1073. return ret;
  1074. }
  1075. PARSER_TIMESTAMP_END(ScopeGraphPass, "TensorFlowModelParser::ScopeGraphPass");
  1076. return SUCCESS;
  1077. }
  1078. Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) {
  1079. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  1080. GE_CHECK_NOTNULL(data);
  1081. GE_CHECK_NOTNULL(graph);
  1082. // Store objects parsed from pb files
  1083. domi::tensorflow::GraphDef OriDef;
  1084. bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef);
  1085. if (!read) {
  1086. REPORT_INNER_ERROR("E19999", "read graph proto from binary failed");
  1087. GELOGE(FAILED, "read_proto_from_binary failed.");
  1088. return INTERNAL_ERROR;
  1089. }
  1090. domi::tensorflow::GraphDef graph_def;
  1091. const bool is_empty_input = GetParserContext().input_dims.empty() && GetParserContext().out_nodes_map.empty();
  1092. if (is_empty_input) {
  1093. graph_def = OriDef;
  1094. } else {
  1095. GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size());
  1096. if (static_cast<bool>(TrimGraph(OriDef, &graph_def))) {
  1097. GELOGE(FAILED, "Trim Graph fail.");
  1098. return INTERNAL_ERROR;
  1099. }
  1100. GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size());
  1101. }
  1102. GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW),
  1103. "Run ProtoType Pass Failed");
  1104. shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
  1105. Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph);
  1106. if (ret != SUCCESS) {
  1107. GELOGE(ret, "[TF ParseFromMemory] scope fusion failed.");
  1108. return ret;
  1109. }
  1110. GELOGD("[TF ParseFromMemory] scope fusion success");
  1111. // Add nodedef in the model to prechecker and check the general parameters
  1112. for (int i = 0; i < graph_def.node_size(); i++) {
  1113. const domi::tensorflow::NodeDef &node = graph_def.node(i);
  1114. GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()),
  1115. "Add node_def to PreChecker failed, node name: %s.", node.name().c_str());
  1116. GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().CheckName(&node), "Check node_def name failed, node name: %s.",
  1117. node.name().c_str());
  1118. if (node.op() != TENSORFLOWF_NODE_OP_IDENTITY) {
  1119. GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().CheckType(&node, true),
  1120. "Check node_def type failed, node name: %s.", node.name().c_str());
  1121. }
  1122. }
  1123. bool has_error = false;
  1124. // save node name
  1125. vector<string> op_node_name_list;
  1126. for (int i = 0; i < graph_def.node_size(); i++) {
  1127. const domi::tensorflow::NodeDef *node_def = graph_def.mutable_node(i);
  1128. // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_
  1129. GE_IF_BOOL_EXEC(MaybeFusionOp(scope_graph, node_def),
  1130. GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str()););
  1131. // Do not exit immediately when there is an error, wait until all errors are collected before exiting
  1132. GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(node_def, op_node_name_list), has_error = true,
  1133. "add node failed.");
  1134. }
  1135. // The fusion operator has passed the verification.
  1136. // The errors of internal non key operators (which will be ignored later)
  1137. // do not affect the transformation of the whole model,
  1138. // So clear the error information of non key operators
  1139. // This function call affects the return value of prechecker::instance().Haserror()
  1140. GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));
  1141. // Building input and input relationships for all OP nodes
  1142. GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
  1143. GELOGD("[TF ParseFromMemory] get op nodes context from graph success");
  1144. // Infer input formats
  1145. ge::GetParserContext().format = InferInputFormats();
  1146. GELOGD("[TF ParseFromMemory] infer input formats success");
  1147. // Building input-output relationship between fusionop and common op
  1148. GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list));
  1149. ret = AddFusionNodeDef(scope_graph, op_node_name_list);
  1150. if (ret != SUCCESS) {
  1151. GELOGE(ret, "Add fusion NodeDef failed.");
  1152. DeleteFuisonNodeDef();
  1153. return ret;
  1154. }
  1155. GELOGI("TF op node size = %zu.", op_node_name_list.size());
  1156. // Loop analysis of op_nodes and map them to nodes in graph
  1157. for (size_t i = 0; i < op_node_name_list.size(); i++) {
  1158. GELOGI("TF op node name = %s.", op_node_name_list[i].c_str());
  1159. const string op_node_name = op_node_name_list[i];
  1160. const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name_list[i]];
  1161. if (node_def == nullptr) {
  1162. REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map_, check invalid", op_node_name.c_str());
  1163. GELOGE(INTERNAL_ERROR, "Node def is nullptr, name:%s.", op_node_name.c_str());
  1164. DeleteFuisonNodeDef();
  1165. return INTERNAL_ERROR;
  1166. }
  1167. const string &node_op = node_def->op();
  1168. if (tensorflow_op_map.find(node_op) == tensorflow_op_map.cend()) {
  1169. DeleteFuisonNodeDef();
  1170. REPORT_INNER_ERROR("E19999", "Op type %s unsupport", node_op.c_str());
  1171. GELOGE(FAILED, "Unsupport op type %s", node_op.c_str());
  1172. return INTERNAL_ERROR;
  1173. }
  1174. ret = AddNode(node_def, graph, scope_graph);
  1175. if (ret != SUCCESS) {
  1176. GELOGE(ret, "Add node failed, name:%s.", op_node_name.c_str());
  1177. DeleteFuisonNodeDef();
  1178. return ret;
  1179. }
  1180. }
  1181. DeleteFuisonNodeDef();
  1182. GE_RETURN_IF_ERROR(AddEdges(graph));
  1183. GE_RETURN_IF_ERROR(graph->TopologicalSorting());
  1184. has_error = has_error || PreChecker::Instance().HasError();
  1185. if (has_error) {
  1186. GELOGE(PARAM_INVALID, "Precheck has errors.");
  1187. return PARAM_INVALID;
  1188. }
  1189. GELOGI("[TF ParseFromMemory] Parse from memory success.");
  1190. return SUCCESS;
  1191. }
  1192. Status TensorFlowModelParser::GetFunctionProto(const string &file,
  1193. domi::tensorflow::GraphDefLibrary &graph_def_library) {
  1194. int pos = file.rfind('/');
  1195. string graph_def_path = (pos == -1) ? kFuncDefLibraryFilePath : file.substr(0, pos) + "/" + kFuncDefLibraryFilePath;
  1196. GELOGI("Function def libraray path is %s.", graph_def_path.c_str());
  1197. bool read = ge::parser::ReadProtoFromText(graph_def_path.c_str(), &graph_def_library);
  1198. if (!read) {
  1199. GELOGE(INTERNAL_ERROR,
  1200. "Get subgraph library failed. "
  1201. "The model contains function operators. "
  1202. "Need to use the script func2graph.py in the atc package to save the subgraphs to graph_def_library.pbtxt");
  1203. ErrorManager::GetInstance().ATCReportErrMessage("E12029");
  1204. return FAILED;
  1205. }
  1206. GELOGI("Get subgraph library success.");
  1207. return SUCCESS;
  1208. }
  1209. Status TensorFlowModelParser::Parse(const char *file, ge::Graph &graph) {
  1210. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  1211. GE_CHECK_NOTNULL(file);
  1212. ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
  1213. GE_CHECK_NOTNULL(root_graph);
  1214. Status ret = Parse(file, root_graph);
  1215. if (ret != SUCCESS) {
  1216. GELOGE(ret, "Parser graph %s failed.", ParserUtils::GetGraphName(graph).c_str());
  1217. return ret;
  1218. }
  1219. GELOGI("Parser graph %s success.", ParserUtils::GetGraphName(graph).c_str());
  1220. return SUCCESS;
  1221. }
  1222. Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &root_graph) {
  1223. GE_CHECK_NOTNULL(model_path);
  1224. GE_CHECK_NOTNULL(root_graph);
  1225. GELOGI("Parse file %s", model_path);
  1226. // Store objects parsed from pb files
  1227. domi::tensorflow::GraphDef ori_def;
  1228. bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def);
  1229. if (!read) {
  1230. GELOGE(FAILED, "read tensorflow file failed when the inupt param value of --framework is 3.");
  1231. return INTERNAL_ERROR;
  1232. }
  1233. // Trim graph by user input and output.
  1234. domi::tensorflow::GraphDef graph_def;
  1235. if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) {
  1236. graph_def = ori_def;
  1237. } else {
  1238. GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size());
  1239. if (static_cast<bool>(TrimGraph(ori_def, &graph_def))) {
  1240. GELOGE(FAILED, "Trim Graph fail.");
  1241. return INTERNAL_ERROR;
  1242. }
  1243. GELOGI("After Trim, The graph_def.node size is:%d", graph_def.node_size());
  1244. }
  1245. // Construct ParseArg for root graph.
  1246. google::protobuf::Message *root_proto = &graph_def;
  1247. std::deque<ParseArg> tasks;
  1248. tasks.push_back({root_proto, "root", nullptr, "", root_graph});
  1249. // Get sub graph from graph_def_library.pbtxt which prepared before and stored in model_path.
  1250. std::map<std::string, domi::tensorflow::GraphDef> function_name_to_graphdef;
  1251. // Parse all root graph and sub graph.
  1252. while (!tasks.empty()) {
  1253. auto arg = tasks.front();
  1254. tasks.pop_front();
  1255. if (arg.proto == nullptr) {
  1256. if (function_name_to_graphdef.empty() && (ori_def.library().function_size() > 0)) {
  1257. GELOGI("Graph has function size: %d ", ori_def.library().function_size());
  1258. domi::tensorflow::GraphDefLibrary graph_def_library;
  1259. GE_CHK_STATUS_RET(GetFunctionProto(model_path, graph_def_library));
  1260. for (auto &ge_graph_def : graph_def_library.graph_def()) {
  1261. function_name_to_graphdef[ge_graph_def.name()] = ge_graph_def.graph();
  1262. GELOGD("Graph_def name: %s, node size: %d", ge_graph_def.name().c_str(), ge_graph_def.graph().node_size());
  1263. }
  1264. }
  1265. const std::map<std::string, domi::tensorflow::GraphDef>::const_iterator
  1266. iter = function_name_to_graphdef.find(arg.function_name);
  1267. if (iter == function_name_to_graphdef.end()) {
  1268. ErrorManager::GetInstance().ATCReportErrMessage("E12013", {"functionname"}, {arg.function_name});
  1269. GELOGE(FAILED, "Failed to get subgraph by function name %s", arg.function_name.c_str());
  1270. return FAILED;
  1271. }
  1272. arg.proto = &(iter->second);
  1273. }
  1274. GELOGI("Begin to parse graph %s", arg.function_name.c_str());
  1275. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW);
  1276. auto ret = model_parser->ParseAllGraph(arg.proto, arg.graph);
  1277. if (ret != SUCCESS) {
  1278. GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(),
  1279. arg.graph->GetName().c_str());
  1280. return ret;
  1281. }
  1282. ret = PostOpProcessForSubgraph(arg);
  1283. if (ret != SUCCESS) {
  1284. // the error log has been printed inner the function
  1285. return ret;
  1286. }
  1287. ret = GenSubgraphParseTasks(arg.graph, tasks);
  1288. if (ret != SUCCESS) {
  1289. REPORT_CALL_ERROR("E19999", "Failed to gen tasks on graph:%s for next iteration", arg.graph->GetName().c_str());
  1290. GELOGE(ret, "Failed to gen tasks on graph %s for next iteration", arg.graph->GetName().c_str());
  1291. return ret;
  1292. }
  1293. }
  1294. return SUCCESS;
  1295. }
  1296. Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
  1297. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  1298. GE_CHECK_NOTNULL(proto);
  1299. GE_CHECK_NOTNULL(graph);
  1300. const domi::tensorflow::GraphDef *ori_graph =
  1301. ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto);
  1302. // Make a copy for operation without modifying the original graph def.
  1303. domi::tensorflow::GraphDef graph_def = *ori_graph;
  1304. GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW),
  1305. "Run ProtoType Pass Failed");
  1306. shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
  1307. Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph);
  1308. if (ret != SUCCESS) {
  1309. GELOGE(ret, "[TF Parse] scope fusion failed.");
  1310. return ret;
  1311. }
  1312. GELOGD("[TF Parse] scope fusion success");
  1313. GE_RETURN_IF_ERROR(OptimizeConstNodes4CustomOp(&graph_def));
  1314. GELOGD("[TF Parse] optimize const nodes for custom op base success");
  1315. // Add nodedef in the model to prechecker and check the general parameters
  1316. // Prevent data residue in multiple calls
  1317. PreChecker::Instance().Clear();
  1318. for (int i = 0; i < graph_def.node_size(); i++) {
  1319. const domi::tensorflow::NodeDef &node = graph_def.node(i);
  1320. GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()),
  1321. "Add node_def to PreChecker failed, node name: %s.", node.name().c_str());
  1322. if (PreChecker::Instance().CheckName(&node) != SUCCESS) {
  1323. GELOGE(FAILED, "Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str());
  1324. return FAILED;
  1325. }
  1326. if (node.op() != TENSORFLOWF_NODE_OP_IDENTITY) {
  1327. if (PreChecker::Instance().CheckType(&node, true) != SUCCESS) {
  1328. GELOGE(FAILED, "Check op[%s]'s optype failed, type is not supported.", node.name().c_str());
  1329. return FAILED;
  1330. }
  1331. }
  1332. }
  1333. bool has_error = false;
  1334. // save node name
  1335. vector<string> op_node_name_list;
  1336. for (int i = 0; i < graph_def.node_size(); i++) {
  1337. const domi::tensorflow::NodeDef *node_def = graph_def.mutable_node(i);
  1338. // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_
  1339. if (MaybeFusionOp(scope_graph, node_def)) {
  1340. GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str());
  1341. }
  1342. // Do not exit immediately when there is an error, wait until all errors are collected before exiting
  1343. GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(node_def, op_node_name_list), has_error = true);
  1344. }
  1345. // The fusion operator has passed the verification.
  1346. // The errors of internal non key operators (which will be ignored later)
  1347. // do not affect the transformation of the whole model,
  1348. // So clear the error information of non key operators
  1349. // This function call affects the return value of prechecker::instance().Haserror()
  1350. GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));
  1351. // Building input and input relationships for all OP nodes
  1352. GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
  1353. GELOGD("[TF Parse] get op nodes context from graph success");
  1354. // Infer input formats
  1355. ge::GetParserContext().format = InferInputFormats();
  1356. GELOGD("[TF Parse] infer input formats success");
  1357. // Building input-output relationship between fusionop and common op
  1358. GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list));
  1359. GELOGD("[TF Parse] update all node op context success");
  1360. // set user-designate-inputs-order
  1361. std::vector<std::string> user_inputs_order;
  1362. for (auto &input : ge::GetParserContext().user_input_dims) {
  1363. user_inputs_order.push_back(input.first);
  1364. }
  1365. graph->SetInputsOrder(user_inputs_order);
  1366. ret = AddFusionNodeDef(scope_graph, op_node_name_list);
  1367. if (ret != SUCCESS) {
  1368. GELOGE(ret, "Add fusion NodeDef failed.");
  1369. DeleteFuisonNodeDef();
  1370. return ret;
  1371. }
  1372. GELOGI("TF op node size = %zu.", op_node_name_list.size());
  1373. // Loop analysis of op_nodes and map them to nodes in graph
  1374. for (size_t i = 0; i < op_node_name_list.size(); i++) {
  1375. GELOGI("TF op node name = %s.", op_node_name_list[i].c_str());
  1376. const string op_node_name = op_node_name_list[i];
  1377. const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name_list[i]];
  1378. if (node_def == nullptr) {
  1379. REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map_, check invalid", op_node_name.c_str());
  1380. GELOGE(INTERNAL_ERROR, "Cannot find [%s] in nodedef map.", op_node_name_list[i].c_str());
  1381. DeleteFuisonNodeDef();
  1382. return INTERNAL_ERROR;
  1383. }
  1384. const string &node_op = node_def->op();
  1385. if (tensorflow_op_map.find(node_op) == tensorflow_op_map.end()) {
  1386. GELOGW("%s not found in tensorflow_op_map.", node_op.c_str());
  1387. }
  1388. ret = AddNode(node_def, graph, scope_graph);
  1389. if (ret != SUCCESS) {
  1390. GELOGE(ret, "Add op[%s] failed", node_def->name().c_str());
  1391. DeleteFuisonNodeDef();
  1392. return ret;
  1393. }
  1394. }
  1395. GELOGD("[TF Parse] parse tf node to geop success");
  1396. DeleteFuisonNodeDef();
  1397. GE_RETURN_IF_ERROR(AddEdges(graph));
  1398. Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
  1399. ParserUtils::OutputMapping final_output_nodes;
  1400. GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes));
  1401. GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes));
  1402. GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
  1403. GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));
  1404. GE_RETURN_IF_ERROR(graph->TopologicalSorting());
  1405. if (has_error) {
  1406. GELOGE(PARAM_INVALID, "Precheck has errors.");
  1407. return PARAM_INVALID;
  1408. }
  1409. GELOGI("[TF Parser] Parse proto success.");
  1410. return SUCCESS;
  1411. }
  1412. Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) {
  1413. // Build the input relationship first
  1414. for (auto &iter : op_node_context_map_) {
  1415. map<string, std::vector<std::pair<int32_t, int32_t>>> input_map;
  1416. const string &op_node_name = iter.first;
  1417. GE_RETURN_IF_ERROR(GetOpNodeInputMap(op_node_name, input_map));
  1418. OpNodeContext &op_node_context = iter.second;
  1419. op_node_context.input_map = input_map;
  1420. }
  1421. // Then build the output relationship
  1422. GE_RETURN_IF_ERROR(GetOpNodeOutputMap(graph_def));
  1423. return SUCCESS;
  1424. }
  1425. // Get the input relation of opnode includeing input_op and input_const
  1426. Status TensorFlowModelParser::GetOpNodeInputMap(const string &op_node_name,
  1427. map<string, std::vector<std::pair<int32_t, int32_t>>> &input_map) {
  1428. // Get the current nodedef according to the node_name
  1429. const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name];
  1430. GE_CHECK_NOTNULL(node_def);
  1431. int32_t input_index = 0;
  1432. int32_t output_index = 0;
  1433. for (const string &input_node_name : node_def->input()) {
  1434. GELOGD("Get Op InputMap, node_name : %s, input node:%s", node_def->name().c_str(), input_node_name.c_str());
  1435. string tmp_node_name;
  1436. bool control = false;
  1437. GE_RETURN_IF_ERROR(CheckInputNodeName(input_node_name, &tmp_node_name, &output_index, &control));
  1438. input_map[tmp_node_name].push_back({output_index, control ? kControlSlot : input_index});
  1439. SaveEdgesControlInfo(node_def->name(), control);
  1440. input_index = control ? input_index : input_index + 1;
  1441. }
  1442. return SUCCESS;
  1443. }
  1444. Status TensorFlowModelParser::GetOpNodeOutputMap(const domi::tensorflow::GraphDef &graph_def) {
  1445. // Loop through all nodes in graphdef
  1446. for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) {
  1447. auto currentIter = op_node_context_map_.find(node_def.name());
  1448. if (currentIter != op_node_context_map_.end()) {
  1449. OpNodeContext &op_node_context = currentIter->second;
  1450. // Find all input nodes of the current node
  1451. for (auto &inputIter : op_node_context.input_map) {
  1452. auto iter = op_node_context_map_.find(inputIter.first);
  1453. if (iter != op_node_context_map_.end()) {
  1454. std::vector<std::pair<int32_t, int32_t>> inputpairs = inputIter.second;
  1455. OpNodeContext &op_node_context1 = iter->second;
  1456. op_node_context1.output_map[node_def.name()].assign(inputpairs.begin(), inputpairs.end());
  1457. }
  1458. }
  1459. }
  1460. }
  1461. return SUCCESS;
  1462. }
  1463. Status TensorFlowModelParser::GeStoi(const string &input_node_name, const string &index_str, int32_t *index) {
  1464. try {
  1465. int32_t tmp_index = static_cast<int32_t>(std::stoi(index_str.c_str(), nullptr, 10));
  1466. *index = tmp_index;
  1467. } catch (std::invalid_argument &) {
  1468. ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"},
  1469. {"input_node_name(" + input_node_name + ")", index_str});
  1470. GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is invalid argument!", input_node_name.c_str(),
  1471. index_str.c_str());
  1472. return INTERNAL_ERROR;
  1473. } catch (std::out_of_range &) {
  1474. ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"},
  1475. {"input_node_name(" + input_node_name + ")", index_str});
  1476. GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is out of range!", input_node_name.c_str(),
  1477. index_str.c_str());
  1478. return INTERNAL_ERROR;
  1479. } catch (...) {
  1480. ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"},
  1481. {"input_node_name(" + input_node_name + ")", index_str});
  1482. GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is bad argument!", input_node_name.c_str(),
  1483. index_str.c_str());
  1484. return INTERNAL_ERROR;
  1485. }
  1486. return SUCCESS;
  1487. }
  1488. Status TensorFlowModelParser::CheckInputNodeName(const string &input_node_name, string *node_name, int32_t *index,
  1489. bool *control) {
  1490. // Processing scene: input: "^fastrcnn_predictions/map/while/Identity""
  1491. string tmp_input_node_name = input_node_name;
  1492. if (tmp_input_node_name.find("^") == 0) {
  1493. tmp_input_node_name = tmp_input_node_name.substr(1, tmp_input_node_name.length() - 1);
  1494. if (control != nullptr) {
  1495. *control = true;
  1496. }
  1497. } else {
  1498. if (control != nullptr) {
  1499. *control = false;
  1500. }
  1501. }
  1502. auto find = tmp_input_node_name.find(":");
  1503. if (find == string::npos) {
  1504. *node_name = tmp_input_node_name;
  1505. if (index == nullptr) {
  1506. return SUCCESS;
  1507. }
  1508. *index = 0;
  1509. return SUCCESS;
  1510. }
  1511. string indexstr = tmp_input_node_name.substr(find + 1, tmp_input_node_name.length() - find - 1);
  1512. *node_name = tmp_input_node_name.substr(0, find);
  1513. if (index == nullptr) {
  1514. return SUCCESS;
  1515. }
  1516. if (GeStoi(input_node_name, indexstr, index) != SUCCESS) {
  1517. return INTERNAL_ERROR;
  1518. }
  1519. return SUCCESS;
  1520. }
  1521. Status TensorFlowModelParser::RunScopeFusionPass(const vector<string> &scope_passes_list,
  1522. ScopePassManager &pass_manager,
  1523. shared_ptr<ge::ScopeGraph> &scope_graph) {
  1524. if (scope_passes_list.empty()) {
  1525. return SUCCESS;
  1526. }
  1527. GE_CHECK_NOTNULL(scope_graph);
  1528. auto &impl = ge::ScopeFusionPassRegistry::GetInstance().impl_;
  1529. if (impl == nullptr) {
  1530. REPORT_INNER_ERROR("E19999", "ScopeFusionPassRegistry is not properly initialized.");
  1531. GELOGE(ge::MEMALLOC_FAILED, "ScopeFusionPassRegistry is not properly initialized.");
  1532. return ge::MEMALLOC_FAILED;
  1533. }
  1534. for (auto &pass_name : scope_passes_list) {
  1535. auto pass = impl->CreateScopeFusionPass(pass_name);
  1536. if (pass == nullptr) {
  1537. REPORT_INNER_ERROR("E19999", "Scope fusion pass[%s] is not registered.", pass_name.c_str());
  1538. GELOGE(INTERNAL_ERROR, "Scope fusion pass[%s] is not registered.", pass_name.c_str());
  1539. return INTERNAL_ERROR;
  1540. }
  1541. Status ret = pass_manager.AddPass(pass);
  1542. if (ret != SUCCESS) {
  1543. REPORT_CALL_ERROR("E19999", "Add scope fusion pass[%s] failed.", pass_name.c_str());
  1544. GELOGE(INTERNAL_ERROR, "Add scope fusion pass[%s] failed.", pass_name.c_str());
  1545. return INTERNAL_ERROR;
  1546. }
  1547. }
  1548. Status ret = pass_manager.Run(scope_graph);
  1549. if (ret != SUCCESS && ret != domi::SCOPE_NOT_CHANGED) {
  1550. GELOGE(FAILED, "Run scope fusion pass failed, ret:%u.", ret);
  1551. return FAILED;
  1552. }
  1553. return SUCCESS;
  1554. }
  1555. bool TensorFlowModelParser::MaybeFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph,
  1556. const domi::tensorflow::NodeDef *node_def) {
  1557. GE_CHECK_NOTNULL(scope_graph);
  1558. GE_CHECK_NOTNULL(node_def);
  1559. // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_
  1560. ge::ScopeFusionOpInfo info;
  1561. std::vector<ge::ScopeFusionOpInfo> info_list;
  1562. auto &impl = scope_graph->impl_;
  1563. if (impl->IsFusionOpChild(node_def->name(), info_list)) {
  1564. GE_IF_BOOL_EXEC(
  1565. info_list.size() > 0, for (size_t i = 0; i < info_list.size(); ++i) {
  1566. fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type);
  1567. fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].description);
  1568. fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(node_def);
  1569. if (info_list[i].fusion_op_type == "Dropout" &&
  1570. (node_def->op() == "Add" || node_def->op() == "RandomUniform")) {
  1571. fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(0)]);
  1572. }
  1573. if (info_list[i].fusion_op_type == "LayerNorm" && node_def->op() == "Mean") {
  1574. fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(1)]);
  1575. }
  1576. fusion_op_policy_[info_list[i].fusion_node_name] = info_list[i].scope_pass;
  1577. fusion_op_children_[node_def->name()] = info_list[i];
  1578. });
  1579. GE_IF_BOOL_EXEC(info_list.size() == 0, fusion_op_type_map_[info.fusion_node_name].push_back(info.fusion_op_type);
  1580. fusion_op_type_map_[info.fusion_node_name].push_back(info.description);
  1581. fusion_op_nodedef_map_[info.fusion_node_name].push_back(node_def);
  1582. fusion_op_policy_[info.fusion_node_name] = info.scope_pass;
  1583. fusion_op_children_[node_def->name()] = info);
  1584. return true;
  1585. }
  1586. return false;
  1587. }
  1588. bool TensorFlowModelParser::IsFusionOpChild(const string &node_name, ge::ScopeFusionOpInfo *info) {
  1589. GE_CHK_BOOL_EXEC(info != nullptr, REPORT_CALL_ERROR("E19999", "Param info is nullptr, check invalid");
  1590. return false, "fusion info is null.");
  1591. // 1.View in full match fusion strategy first
  1592. // 2.View in scope fusion policy then
  1593. auto iter = fusion_op_children_.find(node_name);
  1594. if (iter != fusion_op_children_.end()) {
  1595. info->node_name = fusion_op_children_[node_name].node_name;
  1596. info->fusion_node_name = fusion_op_children_[node_name].fusion_node_name;
  1597. info->fusion_op_type = fusion_op_children_[node_name].fusion_op_type;
  1598. info->description = fusion_op_children_[node_name].description;
  1599. info->scope_pass = fusion_op_children_[node_name].scope_pass;
  1600. return true;
  1601. }
  1602. return false;
  1603. }
  1604. bool TensorFlowModelParser::FusionOpChildIgnore(const shared_ptr<ge::ScopeGraph> &scope_graph,
  1605. const ge::ScopeFusionOpInfo &info) {
  1606. GE_CHECK_NOTNULL(scope_graph);
  1607. bool ignore = false;
  1608. if (info.scope_pass) {
  1609. // Scope fusion strategy
  1610. auto &impl = scope_graph->impl_;
  1611. ignore = impl->FusionOpChildIgnore(info);
  1612. }
  1613. return ignore;
  1614. }
  1615. bool TensorFlowModelParser::IsFusionOp(const shared_ptr<ge::ScopeGraph> &scope_graph,
  1616. const domi::tensorflow::NodeDef *node_def) {
  1617. // The caller guarantees that the pointer is not null
  1618. auto &impl = scope_graph->impl_;
  1619. return (impl->IsFusionOp(node_def));
  1620. }
  1621. Status TensorFlowModelParser::GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  1622. const int32_t old_index, int32_t &new_index) {
  1623. GE_CHECK_NOTNULL(scope_graph);
  1624. if (info.scope_pass) {
  1625. auto &impl = scope_graph->impl_;
  1626. return impl->GetInputOrOutputIndex(info, old_index, true, new_index);
  1627. }
  1628. GELOGE(INTERNAL_ERROR, "Fusion op should come from scope fusion pass, node name:%s, fusion node name:%s",
  1629. info.node_name.c_str(), info.fusion_node_name.c_str());
  1630. return INTERNAL_ERROR;
  1631. }
  1632. Status TensorFlowModelParser::GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  1633. const int32_t old_index, int32_t &new_index) {
  1634. GE_CHECK_NOTNULL(scope_graph);
  1635. if (info.scope_pass) {
  1636. auto &impl = scope_graph->impl_;
  1637. return impl->GetInputOrOutputIndex(info, old_index, false, new_index);
  1638. }
  1639. GELOGE(INTERNAL_ERROR, "Fusion op should come from scope fusion pass, node name:%s, fusion node name:%s",
  1640. info.node_name.c_str(), info.fusion_node_name.c_str());
  1641. return INTERNAL_ERROR;
  1642. }
  1643. bool TensorFlowModelParser::ConstOpNeedUpdate(const string &op_name) {
  1644. if (nodedef_map_[op_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
  1645. // Normal op need to update
  1646. return true;
  1647. } else {
  1648. auto iter = op_node_context_map_.find(op_name);
  1649. if (iter != op_node_context_map_.end()) {
  1650. ge::ScopeFusionOpInfo info;
  1651. auto outmap = iter->second.output_map;
  1652. for (auto &out_node : outmap) {
  1653. // if the const op output connected to are all fusion ops and the cosnt op is not in the update vector
  1654. if (!IsFusionOpChild(out_node.first, &info)) {
  1655. return true;
  1656. }
  1657. }
  1658. }
  1659. return true;
  1660. }
  1661. }
  1662. Status TensorFlowModelParser::UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> &scope_graph,
  1663. vector<string> &op_node_name_list) {
  1664. GE_CHECK_NOTNULL(scope_graph);
  1665. vector<string> tmp_op_node_name_list;
  1666. map<string, OpNodeContext> tmp_fusion_op_node_context_map;
  1667. for (auto &op_node_name : op_node_name_list) {
  1668. auto iter = op_node_context_map_.find(op_node_name);
  1669. if (iter != op_node_context_map_.end()) {
  1670. ge::ScopeFusionOpInfo info;
  1671. if (IsFusionOpChild(op_node_name, &info) && nodedef_map_[op_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
  1672. // This node is a fusion operator
  1673. const std::map<std::string, OpNodeContext>::const_iterator
  1674. fusion_iter = tmp_fusion_op_node_context_map.find(info.fusion_node_name);
  1675. if (fusion_iter == tmp_fusion_op_node_context_map.end()) {
  1676. OpNodeContext op_node_context;
  1677. tmp_fusion_op_node_context_map[info.fusion_node_name] = op_node_context;
  1678. tmp_op_node_name_list.push_back(info.fusion_node_name);
  1679. }
  1680. OpNodeContext &fusion_op_node_context = tmp_fusion_op_node_context_map[info.fusion_node_name];
  1681. OpNodeContext &normal_op_node_context = op_node_context_map_[op_node_name];
  1682. GE_RETURN_IF_ERROR(UpdateFusionOpContext(scope_graph, info, fusion_op_node_context, normal_op_node_context));
  1683. // Delete fusion operator context
  1684. op_node_context_map_.erase(iter);
  1685. } else {
  1686. // This node is a common operator
  1687. OpNodeContext &normal_op_node_context = op_node_context_map_[op_node_name];
  1688. GE_RETURN_IF_ERROR(UpdateNormalOpContext(scope_graph, op_node_name, normal_op_node_context));
  1689. tmp_op_node_name_list.push_back(op_node_name);
  1690. }
  1691. }
  1692. }
  1693. // update op_node_name_list
  1694. op_node_name_list.clear();
  1695. op_node_name_list.assign(tmp_op_node_name_list.begin(), tmp_op_node_name_list.end());
  1696. // update op_node_context_map_
  1697. for (const auto &iter : tmp_fusion_op_node_context_map) {
  1698. op_node_context_map_[iter.first] = iter.second;
  1699. }
  1700. // Normalized context
  1701. GE_RETURN_IF_ERROR(NormalizeAllNodeOpContext());
  1702. return SUCCESS;
  1703. }
  1704. Status TensorFlowModelParser::UpdateFusionOpContext(shared_ptr<ge::ScopeGraph> &scope_graph,
  1705. const ge::ScopeFusionOpInfo &info,
  1706. OpNodeContext &fusion_op_node_context,
  1707. OpNodeContext &normal_op_node_context) {
  1708. GE_CHECK_NOTNULL(scope_graph);
  1709. if (FusionOpChildIgnore(scope_graph, info)) {
  1710. // The inner children operators of the fusion operator can be ignored directly
  1711. // if they do not establish the edge relationship with other outer ordinary / fusion operators
  1712. return SUCCESS;
  1713. }
  1714. GE_CHK_STATUS_RET(UppdateInputMap(scope_graph, info, fusion_op_node_context, normal_op_node_context),
  1715. "UppdateInputMap ret fail");
  1716. GE_CHK_STATUS_RET(UppdateOutputMap(scope_graph, info, fusion_op_node_context, normal_op_node_context),
  1717. "UppdateOutputMap ret fail");
  1718. return SUCCESS;
  1719. }
  1720. Status TensorFlowModelParser::UppdateInputMap(shared_ptr<ge::ScopeGraph> &scope_graph,
  1721. const ge::ScopeFusionOpInfo &info, OpNodeContext &fusion_op_node_context,
  1722. OpNodeContext &normal_op_node_context) {
  1723. GE_CHECK_NOTNULL(scope_graph);
  1724. for (auto &iter : normal_op_node_context.input_map) {
  1725. string input_node_name = iter.first;
  1726. std::vector<std::pair<int32_t, int32_t>> &pairs = iter.second;
  1727. ge::ScopeFusionOpInfo from_info;
  1728. int32_t from_index = 0;
  1729. int32_t to_index = 0;
  1730. if (!ConstOpNeedUpdate(input_node_name)) {
  1731. GELOGI("%s is const node connected to a fusion child, ignore", input_node_name.c_str());
  1732. continue;
  1733. }
  1734. if (IsFusionOpChild(input_node_name, &from_info)) {
  1735. if (info.fusion_node_name == from_info.fusion_node_name) {
  1736. // Ignore two sub operators in the same fusion operator
  1737. continue;
  1738. }
  1739. for (auto &pair : pairs) {
  1740. GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, from_info, pair.first, from_index),
  1741. "GetOutPutIndex failed ,input_node_name %s.", input_node_name.c_str());
  1742. GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, info, pair.second, to_index),
  1743. "GetInPutIndex failed ,input_node_name %s.", input_node_name.c_str());
  1744. fusion_op_node_context.input_map[from_info.fusion_node_name].push_back({from_index, to_index});
  1745. UpdateEdgesControlInfo(info);
  1746. GELOGD("[Update op context] update fusion input map for fusion input, %s:%d TO %s:%d",
  1747. from_info.fusion_node_name.c_str(), from_index, info.fusion_node_name.c_str(), to_index);
  1748. }
  1749. } else {
  1750. for (auto &pair : pairs) {
  1751. from_index = pair.first;
  1752. GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, info, pair.second, to_index),
  1753. "GetInPutIndex input_node_name %s.", input_node_name.c_str());
  1754. fusion_op_node_context.input_map[input_node_name].push_back({from_index, to_index});
  1755. UpdateEdgesControlInfo(info);
  1756. GELOGD("[Update op context] update fusion input map for normal input, %s:%d TO %s:%d",
  1757. input_node_name.c_str(), from_index, info.fusion_node_name.c_str(), to_index);
  1758. }
  1759. }
  1760. }
  1761. return SUCCESS;
  1762. }
  1763. Status TensorFlowModelParser::UppdateOutputMap(shared_ptr<ge::ScopeGraph> &scope_graph,
  1764. const ge::ScopeFusionOpInfo &info, OpNodeContext &fusion_op_node_context,
  1765. OpNodeContext &normal_op_node_context) {
  1766. GE_CHECK_NOTNULL(scope_graph);
  1767. for (auto &iter : normal_op_node_context.output_map) {
  1768. string output_node_name = iter.first;
  1769. std::vector<std::pair<int32_t, int32_t>> &pairs = iter.second;
  1770. ge::ScopeFusionOpInfo to_info;
  1771. int32_t from_index = 0;
  1772. int32_t to_index = 0;
  1773. if (IsFusionOpChild(output_node_name, &to_info)) {
  1774. if (info.fusion_node_name == to_info.fusion_node_name) {
  1775. // Ignore two sub operators in the same fusion operator
  1776. continue;
  1777. }
  1778. for (auto &pair : pairs) {
  1779. GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, info, pair.first, from_index),
  1780. "fusion GetOutPutIndex failed ,output_node_name %s.", output_node_name.c_str());
  1781. GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, to_info, pair.second, to_index),
  1782. "fusion GetInPutIndex failed ,output_node_name %s.", output_node_name.c_str());
  1783. fusion_op_node_context.output_map[to_info.fusion_node_name].push_back({from_index, to_index});
  1784. GELOGD("[Update op context] update fusion output map for fusion output, %s:%d TO %s:%d",
  1785. info.fusion_node_name.c_str(), from_index, to_info.fusion_node_name.c_str(), to_index);
  1786. }
  1787. } else {
  1788. for (auto &pair : pairs) {
  1789. to_index = pair.second;
  1790. GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, info, pair.first, from_index),
  1791. "not fusion,GetOutPutIndex failed ,output_node_name %s.", output_node_name.c_str());
  1792. fusion_op_node_context.output_map[output_node_name].push_back({from_index, to_index});
  1793. GELOGD("[Update op context] update fusion output map for normal output, %s:%d TO %s:%d",
  1794. info.fusion_node_name.c_str(), from_index, output_node_name.c_str(), to_index);
  1795. }
  1796. }
  1797. }
  1798. return SUCCESS;
  1799. }
  1800. Status TensorFlowModelParser::EraseNormalOpOutputIfChild(shared_ptr<ge::ScopeGraph> &scope_graph,
  1801. const string &op_node_name,
  1802. OpNodeContext &normal_op_node_context) {
  1803. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> tmp_output_map;
  1804. for (auto iter = normal_op_node_context.output_map.begin(); iter != normal_op_node_context.output_map.end();) {
  1805. string output_node_name = iter->first;
  1806. ge::ScopeFusionOpInfo to_info;
  1807. if (IsFusionOpChild(output_node_name, &to_info) &&
  1808. nodedef_map_[output_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
  1809. // Fuse operator, update index
  1810. std::vector<std::pair<int32_t, int32_t>> &pairs = iter->second;
  1811. int32_t to_index = 0;
  1812. for (auto &pair : pairs) {
  1813. int32_t from_index = pair.first;
  1814. GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, to_info, pair.second, to_index),
  1815. "GetInPutIndex failed ,output_node_name %s.", output_node_name.c_str());
  1816. tmp_output_map[to_info.fusion_node_name].push_back({from_index, to_index});
  1817. GELOGD("[Update op context] update normal output map for fusion output, %s:%d TO %s:%d", op_node_name.c_str(),
  1818. from_index, to_info.fusion_node_name.c_str(), to_index);
  1819. }
  1820. iter = normal_op_node_context.output_map.erase(iter);
  1821. } else {
  1822. iter++;
  1823. }
  1824. }
  1825. for (auto &iter : tmp_output_map) {
  1826. normal_op_node_context.output_map[iter.first] = iter.second;
  1827. }
  1828. return SUCCESS;
  1829. }
  1830. Status TensorFlowModelParser::UpdateNormalOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, const string &op_node_name,
  1831. OpNodeContext &normal_op_node_context) {
  1832. GE_CHECK_NOTNULL(scope_graph);
  1833. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> tmp_input_map;
  1834. for (auto iter = normal_op_node_context.input_map.begin(); iter != normal_op_node_context.input_map.end();) {
  1835. string input_node_name = iter->first;
  1836. ge::ScopeFusionOpInfo from_info;
  1837. if (IsFusionOpChild(input_node_name, &from_info) &&
  1838. nodedef_map_[input_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
  1839. // Fuse operator, update index
  1840. std::vector<std::pair<int32_t, int32_t>> &pairs = iter->second;
  1841. int32_t from_index = 0;
  1842. for (auto &pair : pairs) {
  1843. int32_t to_index = pair.second;
  1844. GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, from_info, pair.first, from_index),
  1845. "GetOutPutIndex failed ,input_node_name %s.", input_node_name.c_str());
  1846. tmp_input_map[from_info.fusion_node_name].push_back({from_index, to_index});
  1847. GELOGD("[Update op context] update normal input map for fusion input, %s:%d TO %s:%d",
  1848. from_info.fusion_node_name.c_str(), from_index, op_node_name.c_str(), to_index);
  1849. }
  1850. iter = normal_op_node_context.input_map.erase(iter);
  1851. } else {
  1852. iter++;
  1853. }
  1854. }
  1855. Status ret = EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context);
  1856. if (ret != SUCCESS) {
  1857. return ret;
  1858. }
  1859. for (auto &iter : tmp_input_map) {
  1860. normal_op_node_context.input_map[iter.first] = iter.second;
  1861. }
  1862. return SUCCESS;
  1863. }
  1864. Status TensorFlowModelParser::NormalizeAllNodeOpContext() {
  1865. for (auto iter = op_node_context_map_.begin(); iter != op_node_context_map_.end();) {
  1866. OpNodeContext &context = iter->second;
  1867. NormalizeInputOrOutputMap(iter->first, context.input_map);
  1868. NormalizeInputOrOutputMap(iter->first, context.output_map);
  1869. if ((context.input_map.size() == 0) && (context.output_map.size() == 0)) {
  1870. GELOGD("[Update op context] node: %s will be removed at the back.", iter->first.c_str());
  1871. iter = op_node_context_map_.erase(iter);
  1872. } else {
  1873. iter++;
  1874. }
  1875. }
  1876. return SUCCESS;
  1877. }
  1878. Status TensorFlowModelParser::NormalizeInputOrOutputMap(
  1879. const string &node_name, std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map) {
  1880. if (context_map.empty()) {
  1881. return SUCCESS;
  1882. }
  1883. for (auto iter = context_map.begin(); iter != context_map.end();) {
  1884. std::vector<std::pair<int32_t, int32_t>> &pairs = iter->second;
  1885. std::vector<std::pair<int32_t, int32_t>> temp_pairs;
  1886. std::set<std::string> compare_set;
  1887. for (auto &pair : pairs) {
  1888. bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.cend()) ||
  1889. (fusion_op_children_.find(iter->first) != fusion_op_children_.cend());
  1890. bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.cend()) ||
  1891. (fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.cend());
  1892. if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) &&
  1893. (is_fusion_child || is_fusion_op)) {
  1894. // The edge will be cut off at the back, ignoring
  1895. continue;
  1896. }
  1897. string name = to_string(pair.first) + ":" + to_string(pair.second);
  1898. const std::set<std::string>::const_iterator compare_iter = compare_set.find(name);
  1899. if (compare_iter != compare_set.end()) {
  1900. // pair<from,to> repeat, ignore
  1901. continue;
  1902. }
  1903. temp_pairs.push_back(pair);
  1904. compare_set.insert(name);
  1905. }
  1906. if (temp_pairs.empty()) {
  1907. // If there is no pair, the context can be deleted
  1908. iter = context_map.erase(iter);
  1909. continue;
  1910. } else {
  1911. iter++;
  1912. }
  1913. pairs.clear();
  1914. pairs.assign(temp_pairs.begin(), temp_pairs.end());
  1915. }
  1916. return SUCCESS;
  1917. }
  1918. void TensorFlowModelParser::DeleteFuisonNodeDef() {
  1919. for (auto &fusion_nodedef : fusion_nodedef_list) {
  1920. GE_DELETE_NEW_SINGLE(fusion_nodedef);
  1921. }
  1922. }
  1923. void TensorFlowModelParser::SaveEdgesControlInfo(const string &node_name, const bool control) {
  1924. if (control) {
  1925. // If the control attribute is true, save the control attribute to edges_control_map
  1926. edges_control_map[node_name].push_back(kControlSlot);
  1927. }
  1928. }
  1929. void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info) {
  1930. const std::map<std::string, std::vector<int32_t>>::const_iterator iter = edges_control_map.find(info.node_name);
  1931. if (iter != edges_control_map.end()) {
  1932. // Delete the original fusion operator node information and add the fusion operator control edge information
  1933. edges_control_map.erase(iter);
  1934. edges_control_map[info.fusion_node_name].push_back(kControlSlot);
  1935. }
  1936. }
  1937. bool TensorFlowModelParser::GetEdgesControlInfo(const string &node_name, const int32_t index) const {
  1938. // If the node name is included, then confirm whether the index is the same
  1939. auto iter = edges_control_map.find(node_name);
  1940. if (iter != edges_control_map.end()) {
  1941. for (auto &i : iter->second) {
  1942. if (i == index) {
  1943. return true;
  1944. }
  1945. }
  1946. }
  1947. return false;
  1948. }
  1949. Status TensorFlowModelParser::ClearFusionOpError(const vector<string> &op_node_name_list) {
  1950. for (const auto &name : op_node_name_list) {
  1951. ge::ScopeFusionOpInfo info;
  1952. if (IsFusionOpChild(name, &info)) {
  1953. const NodeDef *node = nodedef_map_[name];
  1954. GE_CHECK_NOTNULL(node);
  1955. GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Clear(node, "fused and removed."),
  1956. "Clear pre-checking for node %s failed.", node->name().c_str());
  1957. }
  1958. }
  1959. return SUCCESS;
  1960. }
  1961. Status TensorFlowModelParser::ToJson(const char *model_file, const char *json_file) {
  1962. GE_CHK_BOOL_RET_STATUS(model_file != nullptr, FAILED, "model_file is nullptr.");
  1963. GE_CHK_BOOL_RET_STATUS(json_file != nullptr, FAILED, "json_file is nullptr.");
  1964. domi::tensorflow::GraphDef graph_def;
  1965. nlohmann::json j;
  1966. GE_RETURN_WITH_LOG_IF_FALSE(ge::parser::ReadProtoFromBinaryFile(model_file, &graph_def),
  1967. "ReadProtoFromBinaryFile failed, file:%s.", model_file);
  1968. Pb2Json::Message2Json(graph_def, kTfBlackFields, j, true);
  1969. return ModelSaver::SaveJsonToFile(json_file, j);
  1970. }
  1971. Status TensorFlowWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) {
  1972. (void)data;
  1973. (void)size;
  1974. (void)graph;
  1975. return SUCCESS;
  1976. }
  1977. Status TensorFlowWeightsParser::Parse(const char *file, ge::Graph &graph) {
  1978. (void)file;
  1979. (void)graph;
  1980. return SUCCESS;
  1981. }
  1982. Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
  1983. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  1984. ErrorManager::GetInstance().GenWorkStreamIdDefault();
  1985. PARSER_TIMESTAMP_START(ParseProto);
  1986. GE_CHECK_NOTNULL(proto);
  1987. GE_CHECK_NOTNULL(graph);
  1988. ge::GetParserContext().train_flag = true;
  1989. const domi::tensorflow::GraphDef *graph_def_in =
  1990. ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto);
  1991. // Make a copy for operation without modifying the original graph def.
  1992. domi::tensorflow::GraphDef graph_def_operation = *graph_def_in;
  1993. domi::tensorflow::GraphDef *graph_def = &graph_def_operation;
  1994. GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(graph_def, domi::TENSORFLOW),
  1995. "Run ProtoType Pass Failed");
  1996. shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
  1997. Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph);
  1998. if (ret != SUCCESS) {
  1999. GELOGE(ret, "[TF Parser] scope fusion failed.");
  2000. return ret;
  2001. }
  2002. GELOGD("[TF Parser] scope fusion success");
  2003. bool has_error = false;
  2004. // Graphdef optimizes identity
  2005. PARSER_TIMESTAMP_START(GraphDefOptimize);
  2006. GE_RETURN_IF_ERROR(GraphDefOptimize(graph_def));
  2007. PARSER_TIMESTAMP_END(GraphDefOptimize, "TensorFlowModelParser::GraphDefOptimize");
  2008. GELOGD("[TF Parser] graph def optimize success");
  2009. // Optimization for TVM operator
  2010. PARSER_TIMESTAMP_START(OptimizeConstNodes4CustomOp);
  2011. GE_RETURN_IF_ERROR(OptimizeConstNodes4CustomOp(graph_def));
  2012. PARSER_TIMESTAMP_END(OptimizeConstNodes4CustomOp, "TensorFlowModelParser::OptimizeConstNodes4CustomOp");
  2013. GELOGD("[TF Parser] optimize const nodes for custom op success");
  2014. GE_RETURN_IF_ERROR(GetTensorflowGraphInOutMap(graph_def));
  2015. GE_RETURN_IF_ERROR(RemoveIsolateNode(graph_def));
  2016. vector<string> op_node_name_list;
  2017. bool isDatasetInit = false;
  2018. PARSER_TIMESTAMP_START(AddFmkNodeDefToMap);
  2019. for (int i = 0; i < graph_def->node_size(); i++) {
  2020. const domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i);
  2021. if (node_def->op() == ge::parser::IDENTITY && node_def->input_size() == 0) {
  2022. continue;
  2023. }
  2024. if (node_def->op() == ge::parser::SNAPSHOT && node_def->input_size() == 0) {
  2025. continue;
  2026. }
  2027. GE_IF_BOOL_EXEC(node_def->op() == "MakeIterator", isDatasetInit = true);
  2028. // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_
  2029. if (MaybeFusionOp(scope_graph, node_def)) {
  2030. GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str());
  2031. }
  2032. // Do not exit immediately when there is an error, wait until all errors are collected before exiting
  2033. ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
  2034. GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
  2035. }
  2036. PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap");
  2037. GELOGI("[TF Parser] TF subgraph isDatasetInit: %d.", isDatasetInit);
  2038. // Build input and output relationships for all OP nodes
  2039. PARSER_TIMESTAMP_START(GetOpNodesContextFromGraph);
  2040. GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def));
  2041. PARSER_TIMESTAMP_END(GetOpNodesContextFromGraph, "TensorFlowModelParser::GetOpNodesContextFromGraph");
  2042. GELOGD("[TF Parser] Get op nodes context from graph success");
  2043. // Building input-output relationship between fusionop and common op
  2044. GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list));
  2045. GELOGI("[TF Parser] TF op node size = %zu.", op_node_name_list.size());
  2046. PARSER_TIMESTAMP_START(AddFmkNode);
  2047. // Loop analysis of op_nodes and map them to nodes in graph
  2048. ret = AddFmkNode(graph, scope_graph, op_node_name_list, isDatasetInit);
  2049. PARSER_TIMESTAMP_END(AddFmkNode, "TensorFlowModelParser::AddFmkNode");
  2050. GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef(); return ret, "AddFmkNode failed");
  2051. GELOGD("[TF Parser] Add framework node success");
  2052. ret = AddEdges(graph);
  2053. Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
  2054. ParserUtils::OutputMapping final_output_nodes;
  2055. GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes));
  2056. GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes));
  2057. DeleteFuisonNodeDef();
  2058. GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed");
  2059. GELOGD("[TF Parser] Add edges success");
  2060. PARSER_TIMESTAMP_START(RemoveIsolateNode);
  2061. // Delete isolated nodes
  2062. GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
  2063. GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));
  2064. PARSER_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode");
  2065. PARSER_TIMESTAMP_START(TopologicalSorting);
  2066. GE_RETURN_IF_ERROR(graph->TopologicalSorting());
  2067. PARSER_TIMESTAMP_END(TopologicalSorting, "TensorFlowModelParser::TopologicalSorting");
  2068. ge::parser::PassManager iterator_fusion_pass;
  2069. try {
  2070. (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", new ge::IteratorFusionPass(domi::TENSORFLOW));
  2071. } catch (std::bad_alloc &e) {
  2072. GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs.");
  2073. return INTERNAL_ERROR;
  2074. }
  2075. ret = iterator_fusion_pass.Run(graph);
  2076. if (ret != SUCCESS && ret != ge::NOT_CHANGED) {
  2077. GELOGE(ret, "Run graph passes optimize for preprocess failed, ret:%u.", ret);
  2078. return ret;
  2079. }
  2080. has_error = has_error || PreChecker::Instance().HasError();
  2081. if (has_error) {
  2082. GELOGE(PARAM_INVALID, "Precheck has errors.");
  2083. return PARAM_INVALID;
  2084. }
  2085. GELOGI("[TF Parser] Parse proto success.");
  2086. PARSER_TIMESTAMP_END(ParseProto, "TensorFlowModelParser::ParseProto");
  2087. return SUCCESS;
  2088. }
  2089. Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Message *root_proto,
  2090. domi::GetGraphCallback callback, ge::ComputeGraphPtr &root_graph) {
  2091. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  2092. ErrorManager::GetInstance().GenWorkStreamIdDefault();
  2093. GE_CHECK_NOTNULL(root_proto);
  2094. GE_CHECK_NOTNULL(callback);
  2095. GE_CHECK_NOTNULL(root_graph);
  2096. PARSER_TIMESTAMP_START(ParseProtoWithSubgraph);
  2097. std::vector<std::unique_ptr<google::protobuf::Message>> proto_holder;
  2098. std::deque<ParseArg> tasks;
  2099. tasks.push_back({root_proto, "root", nullptr, "", root_graph});
  2100. while (!tasks.empty()) {
  2101. auto arg = tasks.front();
  2102. tasks.pop_front();
  2103. if (arg.proto == nullptr) {
  2104. auto proto = callback(root_proto, arg.function_name);
  2105. if (proto == nullptr) {
  2106. REPORT_CALL_ERROR("E19999", "callback execute failed, func_name:%s", arg.function_name.c_str());
  2107. GELOGE(FAILED, "Failed to get function by name %s", arg.function_name.c_str());
  2108. return FAILED;
  2109. }
  2110. arg.proto = proto.get();
  2111. proto_holder.emplace_back(std::move(proto));
  2112. }
  2113. GELOGI("Begin to parse graph %s", arg.function_name.c_str());
  2114. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW);
  2115. auto ret = model_parser->ParseProto(arg.proto, arg.graph);
  2116. if (ret != SUCCESS) {
  2117. GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(),
  2118. arg.graph->GetName().c_str());
  2119. return ret;
  2120. }
  2121. ret = PostOpProcessForSubgraph(arg);
  2122. if (ret != SUCCESS) {
  2123. // the error log has been printed inner the function
  2124. return ret;
  2125. }
  2126. ret = GenSubgraphParseTasks(arg.graph, tasks);
  2127. if (ret != SUCCESS) {
  2128. GELOGE(ret, "Failed to gen tasks on graph %s for next iteration", arg.graph->GetName().c_str());
  2129. return ret;
  2130. }
  2131. }
  2132. auto add_ret = AddExternalGraph(root_graph);
  2133. if (add_ret != SUCCESS) {
  2134. GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str());
  2135. return add_ret;
  2136. }
  2137. PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph");
  2138. return SUCCESS;
  2139. }
  2140. Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) {
  2141. if (serialized_proto.empty()) {
  2142. GELOGE(FAILED, "Deserialize proto failed as serialized proto is empty");
  2143. return FAILED;
  2144. }
  2145. domi::tensorflow::GraphDef graph_def;
  2146. if (!graph_def.ParseFromString(serialized_proto)) {
  2147. GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed");
  2148. return FAILED;
  2149. }
  2150. return ParseProto(ge::PtrToPtr<domi::tensorflow::GraphDef, const google::protobuf::Message>(&graph_def), graph);
  2151. }
  2152. Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback,
  2153. ge::ComputeGraphPtr &root_graph) {
  2154. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  2155. ErrorManager::GetInstance().GenWorkStreamIdDefault();
  2156. GE_CHECK_NOTNULL(callback);
  2157. GE_CHECK_NOTNULL(root_graph);
  2158. PARSER_TIMESTAMP_START(ParseProtoWithSubgraph);
  2159. std::deque<ParseArg> tasks;
  2160. tasks.push_back({nullptr, "root", nullptr, "", root_graph});
  2161. bool root_parsed = false;
  2162. while (!tasks.empty()) {
  2163. auto arg = tasks.front();
  2164. tasks.pop_front();
  2165. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW);
  2166. Status ret = SUCCESS;
  2167. if (root_parsed) {
  2168. GELOGI("Begin to parse serialized proto of sub graph %s", arg.function_name.c_str());
  2169. ret = model_parser->ParseProto(callback(arg.function_name), arg.graph);
  2170. } else {
  2171. GELOGI("Begin to parse serialized proto of root graph");
  2172. ret = model_parser->ParseProto(root_proto, arg.graph);
  2173. root_parsed = true;
  2174. }
  2175. if (ret != SUCCESS) {
  2176. GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(),
  2177. arg.graph->GetName().c_str());
  2178. return ret;
  2179. }
  2180. ret = PostOpProcessForSubgraph(arg);
  2181. if (ret != SUCCESS) {
  2182. return ret; // the error log has been printed inner the function
  2183. }
  2184. ret = GenSubgraphParseTasks(arg.graph, tasks);
  2185. if (ret != SUCCESS) {
  2186. GELOGE(ret, "Failed to gen tasks for sub graph of graph %s", arg.graph->GetName().c_str());
  2187. return ret;
  2188. }
  2189. }
  2190. auto add_ret = AddExternalGraph(root_graph);
  2191. if (add_ret != SUCCESS) {
  2192. GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str());
  2193. return add_ret;
  2194. }
  2195. PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph");
  2196. return SUCCESS;
  2197. }
  2198. Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def,
  2199. map<string, NodeDef *> &nodedef_map,
  2200. const std::pair<string, int> &input_data,
  2201. const std::vector<string> &control_list) {
  2202. GE_CHECK_NOTNULL(curr_mode_def);
  2203. string curr_node_name = curr_mode_def->name();
  2204. auto context_iter = op_node_context_map_.find(curr_node_name);
  2205. if (context_iter == op_node_context_map_.end()) {
  2206. REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str());
  2207. GELOGE(FAILED, "Can't find op node context.");
  2208. return INTERNAL_ERROR;
  2209. }
  2210. OpNodeContext op_node_context = context_iter->second;
  2211. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map;
  2212. for (auto &output_iter : output_map) {
  2213. const string &output_node_name = output_iter.first;
  2214. domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
  2215. GE_CHECK_NOTNULL(output_node_def);
  2216. auto inputs = output_node_def->mutable_input();
  2217. std::vector<std::string> added_inputs;
  2218. for (auto &input : *inputs) {
  2219. string node_name;
  2220. bool is_control = false;
  2221. if (CheckInputNodeName(input, &node_name, nullptr, &is_control) != SUCCESS) {
  2222. GELOGE(FAILED, "parse node input info failed, node %s, input %s.", output_node_def->name().c_str(),
  2223. input.c_str());
  2224. return FAILED;
  2225. }
  2226. if (node_name == curr_node_name) {
  2227. if (is_control) {
  2228. input = "^" + input_data.first;
  2229. } else if (input_data.second == 0) {
  2230. input = input_data.first;
  2231. } else {
  2232. input = input_data.first + ":" + std::to_string(input_data.second);
  2233. }
  2234. GELOGD("Optimize Snapshot node, dest:%s, set input:%s.", output_node_name.c_str(), input.c_str());
  2235. for (auto &item : control_list) {
  2236. bool is_exist_input = false;
  2237. for (auto &tmp_input : output_node_def->input()) {
  2238. string tmp_node_name;
  2239. if (CheckInputNodeName(tmp_input, &tmp_node_name, nullptr, nullptr) != SUCCESS) {
  2240. GELOGE(INTERNAL_ERROR, "parse node input info failed, node %s, input %s.",
  2241. output_node_def->name().c_str(), tmp_input.c_str());
  2242. return FAILED;
  2243. }
  2244. if (tmp_node_name == item) {
  2245. is_exist_input = true;
  2246. break;
  2247. }
  2248. }
  2249. if (!is_exist_input) {
  2250. added_inputs.push_back("^" + item);
  2251. }
  2252. }
  2253. }
  2254. }
  2255. for (std::string added_input : added_inputs) {
  2256. GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), added_input.c_str());
  2257. output_node_def->add_input(added_input);
  2258. }
  2259. }
  2260. // Clear the input of snapshot and become an isolated node
  2261. curr_mode_def->clear_input();
  2262. return SUCCESS;
  2263. }
  2264. Status TensorFlowModelParser::GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def,
  2265. map<string, NodeDef *> &nodedef_map,
  2266. const vector<NodeDef *> &nodedef_to_optimize) {
  2267. GE_CHECK_NOTNULL(graph_def);
  2268. if (!nodedef_to_optimize.empty()) {
  2269. // Building input and input relationships for all OP nodes
  2270. GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def));
  2271. GELOGD("Optimize snapshot num:%zu.", nodedef_to_optimize.size());
  2272. } else {
  2273. return SUCCESS;
  2274. }
  2275. for (auto &curr_node_def : nodedef_to_optimize) {
  2276. GE_CHECK_NOTNULL(curr_node_def);
  2277. std::pair<string, int> input_data; // src node name, src index
  2278. vector<string> control_list;
  2279. uint32_t data_input_cnt = 0;
  2280. for (auto &input : curr_node_def->input()) {
  2281. string node_name;
  2282. int input_index = 0;
  2283. bool is_control = false;
  2284. if (CheckInputNodeName(input, &node_name, &input_index, &is_control) != SUCCESS) {
  2285. GELOGE(FAILED, "parse SnapShot input info failed, node %s, input %s.", curr_node_def->name().c_str(),
  2286. input.c_str());
  2287. return FAILED;
  2288. }
  2289. if (is_control) {
  2290. control_list.push_back(node_name);
  2291. } else {
  2292. data_input_cnt++;
  2293. input_data = std::make_pair(node_name, input_index);
  2294. }
  2295. }
  2296. if (data_input_cnt != 1) {
  2297. REPORT_INNER_ERROR("E19999", "Node:%s's input data size:%u not equal to 1, check invalid",
  2298. curr_node_def->name().c_str(), data_input_cnt);
  2299. GELOGE(FAILED, "%s op data input size %u invalid", curr_node_def->name().c_str(), data_input_cnt);
  2300. return FAILED;
  2301. }
  2302. // Optimize Snapshot Node
  2303. GE_CHK_STATUS_RET(OptimizeSnapShot(curr_node_def, nodedef_map, input_data, control_list));
  2304. }
  2305. GELOGI("GraphDefOptimizeSnapShot success.");
  2306. return SUCCESS;
  2307. }
  2308. Status TensorFlowModelParser::SetDestNodeName(const domi::tensorflow::NodeDef *const node_current,
  2309. domi::tensorflow::NodeDef *const node_dest, const int32_t input_idx,
  2310. const bool is_control, bool &clear_input_flag) {
  2311. GELOGI("current node name is %s ", node_current->name().c_str());
  2312. clear_input_flag = true;
  2313. if (is_control) {
  2314. string node_current_name = node_current->input(0);
  2315. string current_name;
  2316. if (CheckInputNodeName(node_current_name, &current_name, nullptr, nullptr) != SUCCESS) {
  2317. GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", node_current_name.c_str());
  2318. return FAILED;
  2319. }
  2320. current_name = "^" + current_name;
  2321. GELOGI("set nodeCurrentNameTmp: %s", current_name.c_str());
  2322. node_dest->set_input(input_idx, current_name);
  2323. } else {
  2324. node_dest->set_input(input_idx, node_current->input(0).c_str());
  2325. GELOGD("%s op set input:%s.", node_dest->name().c_str(), node_current->input(0).c_str());
  2326. }
  2327. // DestroyTemporaryVariable node have only one input and one output.
  2328. // If the number of inputs is greater than 1, all subsequent inputs are
  2329. // control edge inputs. Therefore, after deleting DestroyTemporaryVariable,
  2330. // these control edge inputs can be directly connected to nodeDst.
  2331. for (int i = 1; i < node_current->input_size(); ++i) {
  2332. node_dest->add_input(node_current->input(i));
  2333. }
  2334. return SUCCESS;
  2335. }
  2336. void TensorFlowModelParser::OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def,
  2337. domi::tensorflow::NodeDef *const nodeCurrent,
  2338. bool &clearInputFlag) const {
  2339. // Internal call to ensure that the parameter is not empty.
  2340. GELOGI("DestroyTemporaryVariable optimizing.");
  2341. for (int w = 0; w < graph_def->node_size(); w++) {
  2342. domi::tensorflow::NodeDef *nodeDst = graph_def->mutable_node(w);
  2343. GE_IF_BOOL_EXEC(nodeDst->name() == nodeCurrent->name(), continue);
  2344. for (int k = 0; k < nodeDst->input_size(); k++) {
  2345. string nodeDstInputName = nodeDst->input(k);
  2346. string nodeDstInputNameTmp;
  2347. bool isControl = false;
  2348. if (CheckInputNodeName(nodeDstInputName, &nodeDstInputNameTmp, nullptr, &isControl) != SUCCESS) {
  2349. GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeDstInputName.c_str());
  2350. return;
  2351. }
  2352. if (nodeDstInputNameTmp != nodeCurrent->name()) {
  2353. continue;
  2354. }
  2355. if (SetDestNodeName(nodeCurrent, nodeDst, k, isControl, clearInputFlag) != SUCCESS) {
  2356. GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeCurrent->name().c_str());
  2357. return;
  2358. }
  2359. GELOGI("Optimize DestroyTemporaryVariable successful.");
  2360. }
  2361. }
  2362. }
  2363. Status TensorFlowModelParser::GraphDefOptimizeDestroyTemporaryVariable(
  2364. domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *const nodeCurrent) const {
  2365. if (graph_def == nullptr || nodeCurrent == nullptr) {
  2366. REPORT_INNER_ERROR("E19999", "Param graph_def or nodeCurrent is nullptr, check invalid");
  2367. GELOGE(FAILED, "input param is nullptr.");
  2368. return FAILED;
  2369. }
  2370. if (nodeCurrent->op() != ge::parser::DESTROYTEMPORARYVARIABLE) {
  2371. return SUCCESS;
  2372. }
  2373. GELOGI("Optimize DestroyTemporaryVariable, node name is :%s.", nodeCurrent->name().c_str());
  2374. bool clearInputFlag = false;
  2375. google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr_map_destroy = nodeCurrent->mutable_attr();
  2376. domi::tensorflow::AttrValue var_name_attr_destroy = (*attr_map_destroy)[ge::VAR_ATTR_NAME];
  2377. for (int j = 0; j < graph_def->node_size(); j++) {
  2378. domi::tensorflow::NodeDef *nodeTmpVar = graph_def->mutable_node(j);
  2379. GE_IF_BOOL_EXEC(nodeTmpVar->op() != ge::parser::TEMPORARYVARIABLE, continue);
  2380. google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr_map_tmp = nodeTmpVar->mutable_attr();
  2381. domi::tensorflow::AttrValue var_name_attr_tmp = (*attr_map_tmp)[ge::VAR_ATTR_NAME];
  2382. if (var_name_attr_destroy.s() != var_name_attr_tmp.s()) {
  2383. continue;
  2384. }
  2385. // Optimize destroytemporaryvariable operator
  2386. OptimizeDestroyTemporaryVariable(graph_def, nodeCurrent, clearInputFlag);
  2387. if (clearInputFlag) {
  2388. nodeCurrent->clear_input(); // Clear the destroytemporaryvariable input to become an isolated node
  2389. break;
  2390. }
  2391. }
  2392. if (!clearInputFlag) {
  2393. REPORT_INNER_ERROR("E19999", "Optimize DestroyTemporaryVariable failed, node name is :%s.",
  2394. nodeCurrent->name().c_str());
  2395. GELOGE(INTERNAL_ERROR, "Optimize DestroyTemporaryVariable failed, node name is :%s.", nodeCurrent->name().c_str());
  2396. return FAILED;
  2397. }
  2398. return SUCCESS;
  2399. }
  2400. struct DelTransposeInfo {
  2401. domi::tensorflow::NodeDef *node_def; // transpose
  2402. domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next]
  2403. int inputIdx;
  2404. };
  2405. Status GetTransposeInfo(domi::tensorflow::GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo,
  2406. std::map<std::string, DelTransposeInfo> &transposeInfo) {
  2407. GE_CHECK_NOTNULL(graph_def);
  2408. for (int i = 0; i < graph_def->node_size(); ++i) {
  2409. auto node_def = graph_def->mutable_node(i);
  2410. if (node_def->op() == ge::parser::TRANSPOSE) {
  2411. DelTransposeInfo transpose;
  2412. transpose.node_def = node_def;
  2413. transposeInfo.insert(std::make_pair(node_def->name(), transpose));
  2414. } else if (node_def->op() == ge::parser::SOFTMAX) {
  2415. softmaxInfo.insert(std::make_pair(node_def->name(), node_def->input(0)));
  2416. GELOGI("softmax name:%s, input name:%s", node_def->name().c_str(), node_def->input(0).c_str());
  2417. }
  2418. }
  2419. for (auto &itTranspose : transposeInfo) {
  2420. for (int j = 0; j < graph_def->node_size(); ++j) {
  2421. auto nextNodeDef = graph_def->mutable_node(j);
  2422. bool bFind = false;
  2423. for (int k = 0; k < nextNodeDef->input_size(); ++k) {
  2424. if (nextNodeDef->input(k) == itTranspose.first) {
  2425. itTranspose.second.nextNodeDef = nextNodeDef;
  2426. itTranspose.second.inputIdx = k;
  2427. GELOGI("transpose info name:%s, next name:%s, idx:%d", itTranspose.second.node_def->name().c_str(),
  2428. nextNodeDef->name().c_str(), k);
  2429. bFind = true;
  2430. break;
  2431. }
  2432. }
  2433. if (bFind) {
  2434. break;
  2435. }
  2436. }
  2437. }
  2438. return SUCCESS;
  2439. }
  2440. Status EraseTransposeNode(std::map<std::string, std::string> &softmaxInfo,
  2441. std::map<std::string, DelTransposeInfo> &transposeInfo) {
  2442. std::map<std::string, DelTransposeInfo>::const_iterator itTranspose = transposeInfo.begin();
  2443. for (; itTranspose != transposeInfo.end();) {
  2444. // transpose --> softmax
  2445. bool bErase = true;
  2446. if (softmaxInfo.find(itTranspose->second.node_def->input(0)) != softmaxInfo.end() ||
  2447. softmaxInfo.find(itTranspose->second.nextNodeDef->name()) != softmaxInfo.end()) {
  2448. bErase = false;
  2449. }
  2450. if (bErase) {
  2451. GELOGI("erase node name:%s, input(0):%s", itTranspose->first.c_str(),
  2452. itTranspose->second.node_def->input(0).c_str());
  2453. itTranspose = transposeInfo.erase(itTranspose);
  2454. } else {
  2455. ++itTranspose;
  2456. }
  2457. }
  2458. if ((softmaxInfo.size() <= SIZE_MAX / kSoftmaxMultiple) &&
  2459. (softmaxInfo.size() * kSoftmaxMultiple != transposeInfo.size())) {
  2460. GELOGW("softmax size[%zu], transpose size[%zu]", softmaxInfo.size(), transposeInfo.size());
  2461. return FAILED;
  2462. }
  2463. return SUCCESS;
  2464. }
  2465. void TensorFlowModelParser::OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo) {
  2466. for (auto &it : transposeInfo) {
  2467. auto transpose = it.second;
  2468. transpose.nextNodeDef->set_input(transpose.inputIdx, transpose.node_def->input(kTransposeInputIdx));
  2469. transpose.node_def->clear_input();
  2470. }
  2471. }
  2472. void TensorFlowModelParser::SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def) {
  2473. // The caller guarantees that the pointer is not null
  2474. for (int i = 0; i < graph_def->node_size(); ++i) {
  2475. auto node_def = graph_def->mutable_node(i);
  2476. if (node_def->op() == ge::parser::SOFTMAX) {
  2477. domi::tensorflow::AttrValue attr_value;
  2478. attr_value.set_i(1);
  2479. ge::TensorFlowUtil::AddNodeAttr("axis", attr_value, node_def);
  2480. GELOGI("SoftmaxAddAttr, name: %s, input name:%s", node_def->name().c_str(), node_def->input(0).c_str());
  2481. }
  2482. }
  2483. }
  2484. Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph_def) {
  2485. GE_CHECK_NOTNULL(graph_def);
  2486. map<string, NodeDef *> nodedef_map;
  2487. vector<string> op_node_name_list;
  2488. // Save Snapshot
  2489. vector<NodeDef *> snapshot_to_optimize;
  2490. for (int i = 0; i < graph_def->node_size(); i++) {
  2491. // mutable_node return vale is not empty
  2492. domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i);
  2493. const string &node_name = node_def->name();
  2494. Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
  2495. GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
  2496. if (node_def->op() == ge::parser::SNAPSHOT) {
  2497. snapshot_to_optimize.push_back(node_def);
  2498. }
  2499. nodedef_map[node_name] = node_def;
  2500. }
  2501. // Optimize for Snapshot
  2502. GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize));
  2503. for (int i = 0; i < graph_def->node_size(); i++) {
  2504. domi::tensorflow::NodeDef *nodeCurrent = graph_def->mutable_node(i);
  2505. GE_CHK_STATUS_RET(GraphDefOptimizeDestroyTemporaryVariable(graph_def, nodeCurrent));
  2506. }
  2507. // These member variables will be rebuilt later and need to be cleared here.
  2508. nodedef_map_.clear();
  2509. op_node_context_map_.clear();
  2510. return SUCCESS;
  2511. }
  2512. Status TensorFlowModelParser::RemoveIsolateNode(const ge::ComputeGraphPtr &graph) {
  2513. GE_CHECK_NOTNULL(graph);
  2514. auto nodes = graph->GetDirectNode();
  2515. for (auto &n : nodes) {
  2516. // get front 4 char
  2517. if (n->GetName().substr(0, 4) == "dpop") {
  2518. continue;
  2519. }
  2520. if ((n->GetType() == ge::parser::DATA) ||
  2521. (ge::GetParserContext().out_nodes_map.find(n->GetName()) != ge::GetParserContext().out_nodes_map.end())) {
  2522. GELOGI("Can not remove op [%s] because it is data or out node.", n->GetName().c_str());
  2523. continue;
  2524. }
  2525. GE_IF_BOOL_EXEC((((n->GetInAllNodes().size() == 0) && (n->GetOutDataNodes().size() == 0)) ||
  2526. ((n->GetType() == ge::parser::CONSTANTOP || n->GetType() == ge::parser::CONSTANT) &&
  2527. (n->GetOutDataNodes().size() == 0))),
  2528. GE_CHK_STATUS_RET(ge::GraphUtils::IsolateNode(n, {}), "Isolate removed node: %s, type: %s failed",
  2529. n->GetName().c_str(), n->GetType().c_str());
  2530. GE_CHK_STATUS_RET(ge::GraphUtils::RemoveNodeWithoutRelink(graph, n),
  2531. "Remove node: %s, type: %s without relink failed", n->GetName().c_str(),
  2532. n->GetType().c_str()););
  2533. }
  2534. return SUCCESS;
  2535. }
  2536. // The format specified by the command line argument is preferred,
  2537. // if not specified, use InferInputFormats to infer,
  2538. // and if the inference fails, the default NHWC format is used.
  2539. domiTensorFormat_t TensorFlowModelParser::InferInputFormats() {
  2540. GE_IF_BOOL_EXEC(ge::GetParserContext().format != DOMI_TENSOR_RESERVED, return ge::GetParserContext().format);
  2541. domiTensorFormat_t global_input_format = DOMI_TENSOR_RESERVED;
  2542. set<const NodeDef *> visited_node;
  2543. for (auto &node_item : nodedef_map_) {
  2544. // Infer format for data node and save it to ge::GetParserContext().format.
  2545. domiTensorFormat_t format = DOMI_TENSOR_RESERVED;
  2546. const NodeDef *node = node_item.second;
  2547. if (node == nullptr) {
  2548. return format;
  2549. }
  2550. auto it = tensorflow_op_map.find(node->op());
  2551. if (it != tensorflow_op_map.end() && it->second == ge::parser::DATA) {
  2552. GE_IF_BOOL_EXEC(GetNodeFormat(node, NO_TRANSPOSE, format, visited_node) != SUCCESS,
  2553. GELOGW("Cannot infer input format, the NHWC format is used by default, and you can also "
  2554. "specify format by command line arguments.");
  2555. return domi::DOMI_TENSOR_NHWC);
  2556. GE_IF_BOOL_EXEC(global_input_format == DOMI_TENSOR_RESERVED, global_input_format = format);
  2557. GE_IF_BOOL_EXEC(
  2558. format != DOMI_TENSOR_RESERVED && format != global_input_format,
  2559. GELOGW("Multiple data ops with different formats are not supported, "
  2560. "the NHWC format is used by default, and you can also specify format by command line arguments.");
  2561. return domi::DOMI_TENSOR_NHWC);
  2562. }
  2563. }
  2564. return global_input_format == DOMI_TENSOR_RESERVED ? domi::DOMI_TENSOR_NHWC : global_input_format;
  2565. }
  2566. Status TensorFlowModelParser::GetNodeFormat(const NodeDef *node, TfTranspose pred_transpose, domiTensorFormat_t &format,
  2567. set<const NodeDef *> &visited_node) {
  2568. GE_CHECK_NOTNULL(node);
  2569. // Avoid repeated visits.
  2570. GE_IF_BOOL_EXEC(visited_node.find(node) != visited_node.end(), return SUCCESS);
  2571. visited_node.emplace(node);
  2572. GE_IF_BOOL_EXEC(node->op() == TENSORFLOWF_NODE_OP_SWITCH || node->op() == TENSORFLOWF_NODE_OP_MERGE, return SUCCESS);
  2573. // If node has a data_format attribute, format is set according to data_format.
  2574. domi::tensorflow::AttrValue attr;
  2575. if (ge::TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_DATA_FORMAT, attr) && node->op() != ge::parser::BIASADD) {
  2576. GE_RETURN_IF_ERROR(ge::TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING));
  2577. format = (attr.s() == TENSORFLOWF_TENSOR_NCHW) ? domi::DOMI_TENSOR_NCHW : domi::DOMI_TENSOR_NHWC;
  2578. GE_IF_BOOL_EXEC(format == domi::DOMI_TENSOR_NCHW && pred_transpose == TO_NCHW, format = domi::DOMI_TENSOR_NHWC);
  2579. GE_IF_BOOL_EXEC(format == domi::DOMI_TENSOR_NHWC && pred_transpose == TO_NHWC, format = domi::DOMI_TENSOR_NCHW);
  2580. GE_IF_BOOL_EXEC((format == domi::DOMI_TENSOR_NCHW && pred_transpose == TO_NHWC) ||
  2581. (format == domi::DOMI_TENSOR_NHWC && pred_transpose == TO_NCHW),
  2582. GELOGI("Format conflicts with transpose.");
  2583. return FAILED);
  2584. return SUCCESS;
  2585. }
  2586. TfTranspose transpose;
  2587. GE_RETURN_IF_ERROR(GetFormatTranspose(node, transpose));
  2588. GE_IF_BOOL_EXEC(pred_transpose == transpose && pred_transpose != NO_TRANSPOSE,
  2589. GELOGI("Multiple transpose conflicts.");
  2590. return FAILED);
  2591. // If node does not have the data_format attribute, format is set according to the output node.
  2592. string node_name = node->name();
  2593. GE_IF_BOOL_EXEC(op_node_context_map_.find(node_name) == op_node_context_map_.end(),
  2594. GELOGI("node %s not found in op_node_context_map_", node_name.c_str());
  2595. return FAILED);
  2596. domiTensorFormat_t inferred_format = DOMI_TENSOR_RESERVED;
  2597. const OpNodeContext &node_ctx = op_node_context_map_.at(node_name);
  2598. for (const auto &output_item : node_ctx.output_map) {
  2599. auto node_iter = nodedef_map_.find(output_item.first);
  2600. GE_IF_BOOL_EXEC(node_iter == nodedef_map_.end(),
  2601. GELOGI("node %s not found in nodedef_map_", output_item.first.c_str());
  2602. return FAILED);
  2603. const NodeDef *output_node = node_iter->second;
  2604. GE_CHECK_NOTNULL(output_node);
  2605. domiTensorFormat_t output_format = DOMI_TENSOR_RESERVED;
  2606. GE_RETURN_IF_ERROR(GetNodeFormat(output_node, transpose, output_format, visited_node));
  2607. GE_IF_BOOL_EXEC(output_format != DOMI_TENSOR_RESERVED && inferred_format != DOMI_TENSOR_RESERVED &&
  2608. output_format != inferred_format,
  2609. GELOGI("Multiple output formats conflict.");
  2610. return FAILED);
  2611. inferred_format = output_format;
  2612. }
  2613. format = inferred_format;
  2614. return SUCCESS;
  2615. }
  2616. Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc) const {
  2617. GE_CHECK_NOTNULL(transpose_node);
  2618. transpose_direc = NO_TRANSPOSE;
  2619. GE_IF_BOOL_EXEC(transpose_node->op() != TENSORFLOWF_NODE_OP_TRANSPOSE, return SUCCESS);
  2620. GE_IF_BOOL_EXEC(transpose_node->input_size() != kInputNumInt, GELOGI("Input size of transpose is not 2.");
  2621. return FAILED);
  2622. string perm_node_name = transpose_node->input(1);
  2623. auto it = nodedef_map_.find(perm_node_name);
  2624. GE_IF_BOOL_EXEC(it == nodedef_map_.end(), GELOGI("Node %s not found in nodedef_map_.", perm_node_name.c_str());
  2625. return FAILED);
  2626. const NodeDef *perm_node = it->second;
  2627. GE_CHECK_NOTNULL(perm_node);
  2628. domi::tensorflow::AttrValue attr_value;
  2629. GE_IF_BOOL_EXEC(perm_node->op() != TENSORFLOWF_NODE_OP_CONST, GELOGI("Input node of transpose is not const.");
  2630. return FAILED);
  2631. GE_IF_BOOL_EXEC(!ge::TensorFlowUtil::FindAttrValue(perm_node, TENSORFLOW_ATTR_DTYPE, attr_value), return FAILED);
  2632. GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TYPE) != SUCCESS,
  2633. return FAILED);
  2634. domi::tensorflow::DataType type = attr_value.type();
  2635. GE_IF_BOOL_EXEC(type != domi::tensorflow::DT_INT32 && type != domi::tensorflow::DT_INT64, return FAILED);
  2636. GE_IF_BOOL_EXEC(!ge::TensorFlowUtil::FindAttrValue(perm_node, TENSORFLOW_ATTR_VALUE, attr_value), return FAILED);
  2637. GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS,
  2638. return FAILED);
  2639. const TensorProto &tensor = attr_value.tensor();
  2640. const domi::tensorflow::TensorShapeProto &tensor_shape = tensor.tensor_shape();
  2641. GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != parser::DIM_DEFAULT_SIZE,
  2642. return SUCCESS);
  2643. GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS);
  2644. vector<int64_t> perm_value;
  2645. GE_IF_BOOL_EXEC(
  2646. type == domi::tensorflow::DT_INT32,
  2647. const int32_t *data = reinterpret_cast<const int32_t *>(tensor.tensor_content().data());
  2648. for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); });
  2649. GE_IF_BOOL_EXEC(
  2650. type == domi::tensorflow::DT_INT64,
  2651. const int64_t *data = reinterpret_cast<const int64_t *>(tensor.tensor_content().data());
  2652. for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); });
  2653. // 0, 1, 2, 3 present dim num.
  2654. vector<int64_t> perm_to_nchw = {0, 3, 1, 2};
  2655. vector<int64_t> perm_to_nhwc = {0, 2, 3, 1};
  2656. GE_IF_BOOL_EXEC(perm_value == perm_to_nchw, transpose_direc = TO_NCHW);
  2657. GE_IF_BOOL_EXEC(perm_value == perm_to_nhwc, transpose_direc = TO_NHWC);
  2658. return SUCCESS;
  2659. }
  2660. Status TensorFlowModelParser::TrimGraph(const domi::tensorflow::GraphDef &input_graph_def,
  2661. domi::tensorflow::GraphDef *output_graph_def) {
  2662. GE_CHECK_NOTNULL(output_graph_def);
  2663. if (!ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) {
  2664. return TrimGraphByInput(input_graph_def, output_graph_def);
  2665. } else {
  2666. return TrimGraphByOutput(input_graph_def, output_graph_def);
  2667. }
  2668. }
  2669. Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def,
  2670. domi::tensorflow::GraphDef *const output_graph_def) {
  2671. // The caller guarantees that the pointer is not null
  2672. std::set<string> delete_nodes;
  2673. std::set<string> input_nodes;
  2674. for (auto &iter : ge::GetParserContext().input_dims) {
  2675. input_nodes.insert(iter.first);
  2676. }
  2677. std::map<string, const NodeDef *> node_lookup;
  2678. for (const NodeDef &node : input_graph_def.node()) {
  2679. node_lookup[node.name()] = &node;
  2680. }
  2681. std::vector<string> current_inputs;
  2682. for (auto &iter : ge::GetParserContext().input_dims) {
  2683. current_inputs.push_back(iter.first);
  2684. }
  2685. while (!current_inputs.empty()) {
  2686. std::set<string> next_inputs;
  2687. for (const string &current_input : current_inputs) {
  2688. delete_nodes.insert(current_input);
  2689. GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U,
  2690. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  2691. {"input_shape", current_input});
  2692. return FAILED, "Input op[%s] not found in graph.", current_input.c_str());
  2693. const NodeDef *current_node = node_lookup[current_input];
  2694. GE_CHECK_NOTNULL(current_node);
  2695. for (const string &input_name : current_node->input()) {
  2696. string input_node_name = NodeNameFromInput(input_name);
  2697. if (!delete_nodes.count(input_node_name)) {
  2698. next_inputs.insert(input_node_name);
  2699. }
  2700. }
  2701. }
  2702. current_inputs = std::vector<string>(next_inputs.begin(), next_inputs.end());
  2703. }
  2704. domi::tensorflow::GraphDef filtered_graph_def;
  2705. filtered_graph_def.mutable_node()->Clear();
  2706. for (const NodeDef &node : input_graph_def.node()) {
  2707. if (static_cast<bool>(input_nodes.count(node.name()))) {
  2708. *(filtered_graph_def.mutable_node()->Add()) = node;
  2709. }
  2710. if (!delete_nodes.count(node.name())) {
  2711. *(filtered_graph_def.mutable_node()->Add()) = node;
  2712. }
  2713. }
  2714. output_graph_def->Clear();
  2715. for (const NodeDef &node : filtered_graph_def.node()) {
  2716. if (static_cast<bool>(input_nodes.count(node.name()))) {
  2717. NodeDef placeholder_node = node;
  2718. placeholder_node.clear_input();
  2719. GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder"));
  2720. domi::tensorflow::AttrValue attr_value;
  2721. domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape();
  2722. GE_CHECK_NOTNULL(data_shape);
  2723. const ge::ParserContext &ctx = ge::GetParserContext();
  2724. std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
  2725. std::vector<int64_t> designated_dims = input_dims.at(node.name());
  2726. for (int32_t i = 0; i < static_cast<int32_t>(designated_dims.size()); i++) {
  2727. data_shape->add_dim()->set_size(designated_dims[i]);
  2728. }
  2729. google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr = placeholder_node.mutable_attr();
  2730. (*attr)[TENSORFLOW_ATTR_SHAPE] = attr_value;
  2731. GE_CHECK_NOTNULL(output_graph_def->mutable_node());
  2732. *(output_graph_def->mutable_node()->Add()) = placeholder_node;
  2733. } else {
  2734. GE_CHECK_NOTNULL(output_graph_def->mutable_node());
  2735. *(output_graph_def->mutable_node()->Add()) = node;
  2736. }
  2737. }
  2738. return SUCCESS;
  2739. }
  2740. Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def,
  2741. domi::tensorflow::GraphDef *const output_graph_def) {
  2742. // The caller guarantees that the pointer is not null
  2743. std::set<string> required_nodes;
  2744. std::set<string> input_nodes;
  2745. for (auto &iter : ge::GetParserContext().input_dims) {
  2746. required_nodes.insert(iter.first);
  2747. input_nodes.insert(iter.first);
  2748. }
  2749. for (auto &iter : ge::GetParserContext().out_nodes_map) {
  2750. required_nodes.insert(iter.first);
  2751. }
  2752. std::map<string, const NodeDef *> node_lookup;
  2753. for (const NodeDef &node : input_graph_def.node()) {
  2754. node_lookup[node.name()] = &node;
  2755. }
  2756. std::vector<string> current_inputs;
  2757. for (auto &iter : ge::GetParserContext().out_nodes_map) {
  2758. current_inputs.push_back(iter.first);
  2759. }
  2760. while (!current_inputs.empty()) {
  2761. std::set<string> next_inputs;
  2762. for (const string &current_input : current_inputs) {
  2763. required_nodes.insert(current_input);
  2764. GE_IF_BOOL_EXEC(static_cast<bool>(input_nodes.count(current_input)), continue);
  2765. GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U,
  2766. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  2767. {"out_nodes", current_input});
  2768. return FAILED, "op[%s] not found in graph.", current_input.c_str());
  2769. const NodeDef *current_node = node_lookup[current_input];
  2770. GE_CHECK_NOTNULL(current_node);
  2771. for (const string &input_name : current_node->input()) {
  2772. string input_node_name = NodeNameFromInput(input_name);
  2773. if (!required_nodes.count(input_node_name)) {
  2774. next_inputs.insert(input_node_name);
  2775. }
  2776. }
  2777. }
  2778. current_inputs = std::vector<string>(next_inputs.begin(), next_inputs.end());
  2779. }
  2780. domi::tensorflow::GraphDef filtered_graph_def;
  2781. filtered_graph_def.mutable_node()->Clear();
  2782. for (const NodeDef &node : input_graph_def.node()) {
  2783. if (static_cast<bool>(required_nodes.count(node.name()))) {
  2784. *(filtered_graph_def.mutable_node()->Add()) = node;
  2785. }
  2786. }
  2787. output_graph_def->Clear();
  2788. for (const NodeDef &node : filtered_graph_def.node()) {
  2789. if (static_cast<bool>(input_nodes.count(node.name()))) {
  2790. NodeDef placeholder_node = node;
  2791. placeholder_node.clear_input();
  2792. GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder"));
  2793. domi::tensorflow::AttrValue attr_value;
  2794. domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape();
  2795. GE_CHECK_NOTNULL(data_shape);
  2796. const ge::ParserContext &ctx = ge::GetParserContext();
  2797. std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
  2798. std::vector<int64_t> designated_dims = input_dims.at(node.name());
  2799. for (int32_t i = 0; i < static_cast<int32_t>(designated_dims.size()); i++) {
  2800. data_shape->add_dim()->set_size(designated_dims[i]);
  2801. }
  2802. google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr = placeholder_node.mutable_attr();
  2803. (*attr)[TENSORFLOW_ATTR_SHAPE] = attr_value;
  2804. GE_CHECK_NOTNULL(output_graph_def->mutable_node());
  2805. *(output_graph_def->mutable_node()->Add()) = placeholder_node;
  2806. } else {
  2807. GE_CHECK_NOTNULL(output_graph_def->mutable_node());
  2808. *(output_graph_def->mutable_node()->Add()) = node;
  2809. }
  2810. }
  2811. return SUCCESS;
  2812. }
  2813. string TensorFlowModelParser::NodeNameFromInput(const string &input_name) {
  2814. string prefix;
  2815. string node_name;
  2816. string suffix;
  2817. std::vector<string> input_parts = ge::StringUtils::Split(input_name, ':');
  2818. suffix = (input_parts.size() < kInputNumUint) ? "" : (":" + input_parts[1]);
  2819. string tmp_name = input_parts[0];
  2820. GE_IF_BOOL_EXEC(input_parts[0].find("^") == 0, tmp_name = tmp_name.substr(1, tmp_name.length() - 1));
  2821. node_name = tmp_name;
  2822. return node_name;
  2823. }
  2824. Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_parser,
  2825. const domi::tensorflow::NodeDef *node_def,
  2826. ge::NodePtr &node) const {
  2827. GE_CHECK_NOTNULL(node_def);
  2828. GE_CHECK_NOTNULL(node);
  2829. GE_CHECK_NOTNULL(op_parser);
  2830. GELOGI("FusionNodeParseParams:node name:%s.", node_def->name().c_str());
  2831. // The fusion operator deals with parseparams separately
  2832. shared_ptr<TensorFlowFusionOpParser> tensorflow_fusion_op_parser =
  2833. std::dynamic_pointer_cast<TensorFlowFusionOpParser>(op_parser);
  2834. GE_IF_BOOL_EXEC(tensorflow_fusion_op_parser == nullptr,
  2835. REPORT_INNER_ERROR("E19999", "Param op_parser is not TensorFlowFusionOpParser Type, check invalid");
  2836. GELOGE(FAILED, "node :%s can not get fusion parser, please check!", node_def->name().c_str());
  2837. return INTERNAL_ERROR);
  2838. // Find all children of the fusion operator
  2839. auto iter = fusion_op_nodedef_map_.find(node_def->name());
  2840. if (iter == fusion_op_nodedef_map_.end()) {
  2841. REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid",
  2842. node_def->name().c_str());
  2843. GELOGE(FAILED, "FusionOp node %s has no children node!", node_def->name().c_str());
  2844. return INTERNAL_ERROR;
  2845. }
  2846. (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op());
  2847. vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second;
  2848. domi::FusionParseParamByOpFunc parse_param_func =
  2849. domi::OpRegistry::Instance()->GetFusionParseParamByOpFunc(node->GetType(), node_def->op());
  2850. Status status = FAILED;
  2851. if (parse_param_func == nullptr) {
  2852. status = tensorflow_fusion_op_parser->ParseParams(node_def_v, node);
  2853. GE_CHK_STATUS_EXEC(status, return status, "Parse Params for fusionop node %s failed", node_def->name().c_str());
  2854. } else {
  2855. vector<ge::Operator> op_src_vec;
  2856. for (const auto &node_def_src : node_def_v) {
  2857. ge::Operator op_src(node_def_src->name().c_str(), node_def_src->op().c_str());
  2858. status = domi::OperatorAutoMapping(node_def_src, op_src);
  2859. if (status != SUCCESS) {
  2860. REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", node_def_src->name().c_str(),
  2861. node_def_src->op().c_str());
  2862. GELOGE(status, "Node[%s] auto mapping failed", node_def_src->name().c_str());
  2863. return status;
  2864. }
  2865. auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_src);
  2866. GE_CHECK_NOTNULL(op_desc);
  2867. for (int32_t i = 0; i < node_def_src->input_size(); i++) {
  2868. ge::GeTensorDesc tensor_desc;
  2869. tensor_desc.SetName(node_def_src->input(i));
  2870. if (op_desc->AddInputDesc(tensor_desc) != GRAPH_SUCCESS) {
  2871. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", op_desc->GetName().c_str(),
  2872. op_desc->GetType().c_str());
  2873. GELOGE(FAILED, "Op [%s] type[%s] add input(%d) tensor failed.", op_desc->GetName().c_str(),
  2874. op_desc->GetType().c_str(), i);
  2875. return FAILED;
  2876. }
  2877. }
  2878. op_src_vec.push_back(op_src);
  2879. }
  2880. shared_ptr<TensorFlowFusionCustomParserAdapter> tf_custom_fusion_op_paser =
  2881. std::dynamic_pointer_cast<TensorFlowFusionCustomParserAdapter>(tensorflow_fusion_op_parser);
  2882. status = tf_custom_fusion_op_paser->ParseParams(op_src_vec, node);
  2883. if (status != SUCCESS) {
  2884. GELOGE(status, "Parse params for fusionop node %s failed", node_def->name().c_str());
  2885. return status;
  2886. }
  2887. }
  2888. return SUCCESS;
  2889. }
  2890. /**
  2891. * @ingroup domi_omg
  2892. * @brief Optimizing const nodes for custom operators
  2893. * @param [in] graph_def graph object
  2894. * @return true optimize successfully
  2895. * @return false optimize failed
  2896. *
  2897. */
  2898. Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const {
  2899. GE_CHECK_NOTNULL(graph_def);
  2900. // 1. find all the nodes in the graph and save them to all_nodedef_map
  2901. map<string, NodeDef *> all_nodedef_map;
  2902. int graph_node_size = graph_def->node_size();
  2903. for (int i = 0; i != graph_node_size; ++i) {
  2904. // mutable_node return vale is not empty
  2905. domi::tensorflow::NodeDef *current_node = graph_def->mutable_node(i);
  2906. string node_name = current_node->name();
  2907. all_nodedef_map[node_name] = current_node;
  2908. }
  2909. GELOGD("node size is: %zu", all_nodedef_map.size());
  2910. // 2. move input to attr.
  2911. for (auto &it_node_map : all_nodedef_map) {
  2912. domi::tensorflow::NodeDef *current_node = it_node_map.second;
  2913. GE_CHECK_NOTNULL(current_node);
  2914. string current_op_name = current_node->op();
  2915. // 2.1. check whether the current op is register for move to attr.
  2916. const std::vector<domi::RemoveInputConfigure> &move_input_vec =
  2917. domi::OpRegistry::Instance()->GetRemoveInputConfigure(current_op_name);
  2918. // 2.2 check whether the current op is a TVM op.
  2919. const bool is_unknown_custom_op = move_input_vec.empty() ||
  2920. (domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) != domi::ImplyType::TVM);
  2921. if (is_unknown_custom_op) {
  2922. GELOGI("op %s is not TVM op, move input size: %zu", current_op_name.c_str(), move_input_vec.size());
  2923. continue;
  2924. }
  2925. GELOGD("Current op %s is registered for remove input and tvm op", current_op_name.c_str());
  2926. // 2.3 copy input to attr
  2927. set<uint32_t> unused_inputs;
  2928. for (const auto &it : move_input_vec) {
  2929. uint32_t move_index;
  2930. if (it.inputIdx >= 0) {
  2931. move_index = it.inputIdx;
  2932. } else {
  2933. GE_IF_BOOL_EXEC(
  2934. -it.inputIdx > current_node->input_size(),
  2935. ErrorManager::GetInstance().ATCReportErrMessage(
  2936. "E12004", {"opname", "inputIdx", "inputsize"},
  2937. {current_op_name, std::to_string(-it.inputIdx), std::to_string(current_node->input_size())});
  2938. GELOGE(INTERNAL_ERROR,
  2939. "Op[%s] register failed, inputIdx[-%d] should be greater than inputsize[%d] when inputIdx < 0.",
  2940. current_op_name.c_str(), it.inputIdx, current_node->input_size());
  2941. return PARAM_INVALID);
  2942. move_index = current_node->input_size() + it.inputIdx;
  2943. }
  2944. // For an isolated node in deep lab V3 networ.
  2945. // solve the problem of protobuf index less current_size.
  2946. GE_IF_BOOL_EXEC(current_node->input_size() == 0, GELOGI("Input size is 0, already optimized"); continue);
  2947. if (it.moveType == domi::OMG_REMOVE_TYPE_WITH_COND) {
  2948. domi::tensorflow::AttrValue attr_value;
  2949. GE_IF_BOOL_EXEC(!(ge::TensorFlowUtil::FindAttrValue(current_node, it.attrName, attr_value)),
  2950. REPORT_INNER_ERROR("E19999", "Op:%s register AttrName[%s] has no value, check invalid",
  2951. current_op_name.c_str(), it.attrName.c_str());
  2952. GELOGE(INTERNAL_ERROR, "AttrName[%s] has no value!", it.attrName.c_str());
  2953. return PARAM_INVALID);
  2954. GE_IF_BOOL_EXEC(attr_value.b() == it.attrValue, unused_inputs.insert(move_index));
  2955. } else if (it.moveType == domi::OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE && it.originalType == current_op_name) {
  2956. GELOGD("Input %s:%d will be removed.", current_op_name.c_str(), move_index);
  2957. unused_inputs.insert(move_index);
  2958. } else if (it.moveType == domi::OMG_INPUT_REORDER) {
  2959. auto inputs = current_node->input();
  2960. if (static_cast<size_t>(inputs.size()) != it.input_order.size()) {
  2961. REPORT_INNER_ERROR("E19999", "Input size of node:%s(%s) is mismatched, new order size:%zu, input size:%d",
  2962. current_node->name().c_str(), current_node->op().c_str(), it.input_order.size(),
  2963. inputs.size());
  2964. GELOGE(INTERNAL_ERROR, "Size of input is mismatched, new order size is %zu, input size is %d.",
  2965. it.input_order.size(), inputs.size());
  2966. return INTERNAL_ERROR;
  2967. }
  2968. for (size_t i = 0; i < it.input_order.size(); ++i) {
  2969. int new_index = it.input_order[i];
  2970. const bool is_input_invalid = (new_index < 0) || (new_index >= inputs.size());
  2971. if (is_input_invalid) {
  2972. REPORT_INNER_ERROR("E19999", "New order of %s has invalid index %d, out of range(0, %d)",
  2973. it_node_map.first.c_str(), new_index, inputs.size());
  2974. GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", it_node_map.first.c_str(), new_index);
  2975. return INTERNAL_ERROR;
  2976. }
  2977. current_node->set_input(i, inputs[new_index]);
  2978. }
  2979. GELOGI("The input sequence of the node has been rearranged, node name:%s.", it_node_map.first.c_str());
  2980. }
  2981. }
  2982. // 2.4 remove the input const nodes
  2983. Status ret = RemoveInputs(graph_def, current_node, unused_inputs, all_nodedef_map);
  2984. if (ret != SUCCESS) {
  2985. REPORT_CALL_ERROR("E19999", "remove input for op:%s failed", current_op_name.c_str());
  2986. GELOGE(INTERNAL_ERROR, "Op[%s] remove input failed.", current_op_name.c_str());
  2987. return ret;
  2988. }
  2989. }
  2990. return SUCCESS;
  2991. }
  2992. Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def,
  2993. domi::tensorflow::NodeDef *node_def,
  2994. const map<string, NodeDef *> &all_node_map,
  2995. const vector<string> &removed_inputs_vec) const {
  2996. GE_CHECK_NOTNULL(graph_def);
  2997. GE_CHECK_NOTNULL(node_def);
  2998. for (const auto &remove_input : removed_inputs_vec) {
  2999. string input_node_name = NodeNameFromInput(remove_input);
  3000. auto it = all_node_map.find(input_node_name);
  3001. if (it == all_node_map.end()) {
  3002. REPORT_INNER_ERROR("E19999", "Node:%s can't find in all_node_map, check invalid", input_node_name.c_str());
  3003. GELOGE(FAILED, "Can not find node name:%s in all node map.", input_node_name.c_str());
  3004. return FAILED;
  3005. }
  3006. NodeDef *input_node_def = it->second;
  3007. if (input_node_def->op() == parser::SWITCH || input_node_def->op() == parser::REFSWITCH) {
  3008. NodeDef *identity_node_def = graph_def->add_node();
  3009. GE_CHECK_NOTNULL(identity_node_def);
  3010. std::string remove_input_name = remove_input;
  3011. remove_input_name = remove_input_name.find(":") == std::string::npos ?
  3012. input_node_name : (remove_input_name.replace(remove_input_name.find(":"), 1, "_"));
  3013. input_node_name = remove_input_name + "_identity";
  3014. identity_node_def->set_name(input_node_name);
  3015. identity_node_def->set_op(parser::IDENTITY);
  3016. identity_node_def->add_input(remove_input);
  3017. }
  3018. string control_input = "^" + input_node_name;
  3019. node_def->add_input(control_input);
  3020. GELOGD("Add control input:%s for node:%s", control_input.c_str(), node_def->name().c_str());
  3021. }
  3022. return SUCCESS;
  3023. }
  3024. /**
  3025. * @ingroup domi_omg
  3026. * @brief Delete input from nodedef
  3027. * @param [in] node_def Nodedef object
  3028. * @param [in] remove_index_set Index collection of input nodes to be deleted
  3029. * @return true remove successfully
  3030. * @return false remove failed
  3031. *
  3032. */
  3033. Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *node_def,
  3034. const set<uint32_t> &remove_index_set,
  3035. const map<string, NodeDef *> &all_node_map) const {
  3036. GE_CHECK_NOTNULL(node_def);
  3037. if (remove_index_set.empty()) {
  3038. GELOGI("The size of remove_index_set is zero.");
  3039. return SUCCESS;
  3040. }
  3041. map<string, vector<int>> remove_inputs_map;
  3042. for (auto &it : remove_index_set) {
  3043. const string &input_node_name = node_def->input(it);
  3044. remove_inputs_map[input_node_name].emplace_back(it);
  3045. GELOGD("Push input:%s, index:%d into remove map.", input_node_name.c_str(), it);
  3046. }
  3047. RemoveInputAttr(node_def, remove_inputs_map);
  3048. int index = 0;
  3049. vector<string> removed_inputs_vec;
  3050. auto *inputs = node_def->mutable_input();
  3051. for (auto input_it = inputs->begin(); input_it != inputs->end(); ++index) {
  3052. // 1.decide whether to remove the input
  3053. bool flag = false;
  3054. for (auto &remove_input : remove_inputs_map) {
  3055. string remove_input_name = remove_input.first;
  3056. vector<int> remove_input_indexs = remove_input.second;
  3057. if ((*input_it) == remove_input_name &&
  3058. std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) {
  3059. GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index);
  3060. flag = true;
  3061. removed_inputs_vec.emplace_back(remove_input_name);
  3062. break;
  3063. }
  3064. }
  3065. if (flag) {
  3066. // 2 remove the input
  3067. input_it = inputs->erase(input_it);
  3068. } else {
  3069. ++input_it;
  3070. }
  3071. }
  3072. Status ret = AddControlEdgeAfterRemoveInputs(graph_def, node_def, all_node_map, removed_inputs_vec);
  3073. if (ret != SUCCESS) {
  3074. GELOGE(FAILED, "Add control edges for node:%s failed.", node_def->name().c_str());
  3075. return FAILED;
  3076. }
  3077. return SUCCESS;
  3078. }
  3079. void TensorFlowModelParser::RemoveInputAttr(domi::tensorflow::NodeDef *node_def,
  3080. const map<string, vector<int>> &remove_inputs_map) const {
  3081. // The caller guarantees that the pointer is not null
  3082. auto *inputs = node_def->mutable_input();
  3083. google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr_map = node_def->mutable_attr();
  3084. const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::iterator it =
  3085. attr_map->find(ge::ATTR_NAME_INPUT_TENSOR_DESC);
  3086. if (it == attr_map->end()) {
  3087. GELOGW("Failed to find input desc from tf node_def[%s]", node_def->name().c_str());
  3088. } else {
  3089. domi::tensorflow::AttrValue *input_attr_value = &(it->second);
  3090. auto tmp_attr = input_attr_value->mutable_list()->mutable_func();
  3091. auto attr_it = tmp_attr->begin();
  3092. int index = 0;
  3093. for (auto input_it = inputs->begin(); input_it != inputs->end(); ++input_it, ++index) {
  3094. // 1.decide whether to remove the input
  3095. bool flag = false;
  3096. for (auto &remove_input : remove_inputs_map) {
  3097. string remove_input_name = remove_input.first;
  3098. vector<int> remove_input_indexs = remove_input.second;
  3099. if ((*input_it) == remove_input_name &&
  3100. std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) {
  3101. GELOGD("Remove input attr:%s, index:%d", remove_input_name.c_str(), index);
  3102. flag = true;
  3103. break;
  3104. }
  3105. }
  3106. if (flag) {
  3107. // 2.1 remove the input attr
  3108. if (!tmp_attr->empty() && attr_it != tmp_attr->end()) {
  3109. attr_it = tmp_attr->erase(attr_it);
  3110. } else {
  3111. ++attr_it;
  3112. }
  3113. } else {
  3114. ++attr_it;
  3115. }
  3116. }
  3117. }
  3118. }
  3119. Status TensorFlowModelParser::GetTensorflowGraphInOutMap(domi::tensorflow::GraphDef *graph_def) {
  3120. GE_CHECK_NOTNULL(graph_def);
  3121. for (int i = 0; i < graph_def->node_size(); i++) {
  3122. domi::tensorflow::NodeDef *node = graph_def->mutable_node(i);
  3123. const string &node_name = node->name();
  3124. node_inputs_outputs_map_.emplace(node_name, std::pair<set<string>, set<string>>{});
  3125. for (const auto &input : node->input()) {
  3126. string input_node_name;
  3127. GE_RETURN_IF_ERROR(CheckInputNodeName(input, &input_node_name, nullptr, nullptr));
  3128. node_inputs_outputs_map_[node_name].first.insert(input_node_name);
  3129. node_inputs_outputs_map_[input_node_name].second.insert(node_name);
  3130. }
  3131. }
  3132. return SUCCESS;
  3133. }
  3134. Status TensorFlowModelParser::RemoveIsolateNode(domi::tensorflow::GraphDef *graph_def) {
  3135. GE_CHECK_NOTNULL(graph_def);
  3136. set<string> node_to_delete;
  3137. for (int i = 0; i < graph_def->node_size(); i++) {
  3138. domi::tensorflow::NodeDef *node = graph_def->mutable_node(i);
  3139. const string &node_name = node->name();
  3140. if (node_inputs_outputs_map_.find(node_name) == node_inputs_outputs_map_.end()) {
  3141. REPORT_INNER_ERROR("E19999", "Node:%s can't find in node_inputs_outputs_map_, check invalid", node_name.c_str());
  3142. GELOGE(FAILED, "Can not find input output context, node:%s.", node_name.c_str());
  3143. return FAILED;
  3144. }
  3145. if ((node_inputs_outputs_map_[node_name].first.empty() && node_inputs_outputs_map_[node_name].second.empty() &&
  3146. node->op() != kDpop) ||
  3147. (node->op() == ge::parser::CONSTANT && node_inputs_outputs_map_[node_name].second.empty())) {
  3148. GELOGI("%s will inset to node_to_delete", node_name.c_str());
  3149. node_to_delete.insert(node_name);
  3150. }
  3151. }
  3152. // delete isolate nodes
  3153. auto nodeList = graph_def->mutable_node();
  3154. for (auto iter = nodeList->begin(); iter != nodeList->end();) {
  3155. if (node_to_delete.count(iter->name()) != 0) {
  3156. GELOGI("%s has zero input and output, will delete.", iter->name().c_str());
  3157. iter = nodeList->erase(iter);
  3158. } else {
  3159. iter++;
  3160. }
  3161. }
  3162. return SUCCESS;
  3163. }
  3164. Status TensorFlowModelParser::RecordFusionResult(const std::shared_ptr<ge::ScopeGraph> &scope_graph,
  3165. const domi::tensorflow::NodeDef *node, const ge::OpDescPtr &op_desc) {
  3166. // The caller guarantees that the pointer is not null
  3167. GELOGI("RecordFusionResult for %s start.", op_desc->GetName().c_str());
  3168. auto &impl_scope_graph = scope_graph->impl_;
  3169. ge::FusionScopesResult *fusion_result = impl_scope_graph->GetFusionScopesResults(node);
  3170. if (fusion_result == nullptr) {
  3171. GELOGW("fusion_result is not found.");
  3172. return SUCCESS;
  3173. }
  3174. std::vector<std::string> original_names;
  3175. auto nodes = fusion_result->Nodes();
  3176. std::transform(nodes.begin(), nodes.end(), std::back_inserter(original_names),
  3177. [](ge::OperatorPtr n) -> std::string { return ParserUtils::GetOperatorName(*n); });
  3178. GELOGI("Op %s original_names size = %zu.", op_desc->GetName().c_str(), original_names.size());
  3179. bool ret = ge::AttrUtils::SetListStr(op_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
  3180. if (!ret) {
  3181. GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str());
  3182. }
  3183. auto outputs_desc = op_desc->GetAllOutputsDesc();
  3184. auto &impl = fusion_result->impl_;
  3185. for (auto &fusion_output : impl->GetOutputs()) {
  3186. for (size_t i = 0; i < fusion_output.second.size(); ++i) {
  3187. if (fusion_output.second[i] == ge::kFusionDisableIndex) {
  3188. continue;
  3189. }
  3190. if (fusion_output.second[i] >= static_cast<int32_t>(op_desc->GetOutputsSize())) {
  3191. REPORT_INNER_ERROR("E19999", "fusion output index:%d of node:%s(%s) must less than outputs desc size %zu.",
  3192. fusion_output.second[i], op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  3193. op_desc->GetOutputsSize());
  3194. GELOGE(PARAM_INVALID, "fusion output index %d must less than outputs desc size %zu.", fusion_output.second[i],
  3195. op_desc->GetOutputsSize());
  3196. return PARAM_INVALID;
  3197. }
  3198. ret = ge::AttrUtils::SetStr(op_desc->MutableOutputDesc(fusion_output.second[i]),
  3199. ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, fusion_output.first);
  3200. if (!ret) {
  3201. GELOGW("Set %s to %s %d output fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME.c_str(), op_desc->GetName().c_str(),
  3202. fusion_output.second[i]);
  3203. }
  3204. ret = ge::AttrUtils::SetInt(op_desc->MutableOutputDesc(fusion_output.second[i]),
  3205. ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, i);
  3206. if (!ret) {
  3207. GELOGW("Set %s to %s %d output fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX.c_str(),
  3208. op_desc->GetName().c_str(), fusion_output.second[i]);
  3209. }
  3210. }
  3211. }
  3212. return SUCCESS;
  3213. }
  3214. Status TensorFlowModelParser::SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context,
  3215. const std::vector<std::pair<std::string, int32_t>> &inputs,
  3216. const std::vector<std::pair<std::string, int32_t>> &outputs) {
  3217. int32_t in_index = 0;
  3218. for (const auto &in : inputs) {
  3219. bool is_ctrl = in.second == kControlSlot;
  3220. op_node_context.input_map[in.first].emplace_back(std::make_pair(in.second, is_ctrl ? kControlSlot : in_index));
  3221. SaveEdgesControlInfo(node_def->name(), is_ctrl);
  3222. in_index = is_ctrl ? in_index : in_index + 1;
  3223. }
  3224. int32_t out_index = 0;
  3225. for (const auto &out : outputs) {
  3226. bool is_ctrl = out.second == kControlSlot;
  3227. op_node_context.output_map[out.first].emplace_back(std::make_pair(is_ctrl ? kControlSlot : out_index, out.second));
  3228. out_index = is_ctrl ? out_index : out_index + 1;
  3229. }
  3230. return SUCCESS;
  3231. }
  3232. void TensorFlowModelParser::GetFusionInputInfo(
  3233. const string &fusion_op_name, OpNodeContext &fusion_context,
  3234. std::map<string, std::pair<std::string, std::pair<int32_t, int32_t>>> &remap_data_input,
  3235. std::map<string, std::vector<string>> &remap_ctrl_input, std::set<string> &fusion_input_nodes) {
  3236. for (const auto &fusion_input : fusion_context.input_map) {
  3237. string fusion_src_name = fusion_input.first;
  3238. for (const auto &fusion_idx_pair : fusion_input.second) {
  3239. string key = fusion_op_name + std::to_string(fusion_idx_pair.second);
  3240. if (fusion_idx_pair.second != kControlSlot) {
  3241. remap_data_input[key] = {fusion_src_name, {fusion_idx_pair.first, fusion_idx_pair.second}};
  3242. } else {
  3243. remap_ctrl_input[key].emplace_back(fusion_src_name);
  3244. }
  3245. }
  3246. fusion_input_nodes.insert(fusion_src_name);
  3247. }
  3248. }
  3249. void TensorFlowModelParser::GetFusionOutputInfo(
  3250. const string &fusion_op_name, OpNodeContext &fusion_context,
  3251. std::map<string, std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>>> &remap_data_output,
  3252. std::map<string, std::vector<string>> &remap_ctrl_output, std::set<string> &fusion_output_nodes) {
  3253. for (const auto &fusion_output : fusion_context.output_map) {
  3254. string fusion_dst_name = fusion_output.first;
  3255. for (const auto &fusion_idx_pair : fusion_output.second) {
  3256. string key = fusion_op_name + std::to_string(fusion_idx_pair.first);
  3257. if (fusion_idx_pair.first != kControlSlot) {
  3258. remap_data_output[key].emplace_back(
  3259. std::make_pair(fusion_dst_name, std::make_pair(fusion_idx_pair.first, fusion_idx_pair.second)));
  3260. } else {
  3261. remap_ctrl_output[key].emplace_back(fusion_dst_name);
  3262. }
  3263. }
  3264. fusion_output_nodes.insert(fusion_dst_name);
  3265. }
  3266. }
  3267. void TensorFlowModelParser::UpdateInnerInputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
  3268. const std::vector<std::string> &inner_nodes_name,
  3269. std::set<string> &fusion_input_nodes) {
  3270. std::map<string, std::pair<std::string, std::pair<int32_t, int32_t>>> remap_data_input;
  3271. std::map<string, std::vector<string>> remap_ctrl_input;
  3272. GetFusionInputInfo(fusion_op_name, fusion_context, remap_data_input, remap_ctrl_input, fusion_input_nodes);
  3273. for (const auto &node_name : inner_nodes_name) {
  3274. auto context_iter = op_node_context_map_.find(node_name);
  3275. if (context_iter != op_node_context_map_.end()) {
  3276. OpNodeContext &op_node_context = context_iter->second;
  3277. // update input map of inner node
  3278. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> tmp_input_map;
  3279. for (auto iter = op_node_context.input_map.begin(); iter != op_node_context.input_map.end();) {
  3280. string src_name = iter->first;
  3281. if (src_name == ge::kInputFromFusionScope) {
  3282. std::vector<std::pair<int32_t, int32_t>> &input_idx = iter->second;
  3283. for (const auto &in_pair : input_idx) {
  3284. if (in_pair.second != kControlSlot) {
  3285. auto data = remap_data_input[fusion_op_name + std::to_string(in_pair.first)];
  3286. tmp_input_map[data.first].emplace_back(std::make_pair(data.second.first, in_pair.second));
  3287. GELOGI("Update inner input, src:%s, idx:%u->%u", data.first.c_str(), data.second.first, in_pair.second);
  3288. }
  3289. }
  3290. auto ctrl = remap_ctrl_input[fusion_op_name + std::to_string(kControlSlot)];
  3291. for (const auto &ctrl_in : ctrl) {
  3292. tmp_input_map[ctrl_in].emplace_back(std::make_pair(kControlSlot, kControlSlot));
  3293. SaveEdgesControlInfo(node_name, kControlSlot);
  3294. }
  3295. iter = op_node_context.input_map.erase(iter);
  3296. } else {
  3297. ++iter;
  3298. }
  3299. }
  3300. op_node_context.input_map.insert(tmp_input_map.cbegin(), tmp_input_map.cend());
  3301. // update output map of pre node
  3302. for (const auto &in_iter : op_node_context.input_map) {
  3303. auto src_iter = op_node_context_map_.find(in_iter.first);
  3304. if (src_iter != op_node_context_map_.end()) {
  3305. std::vector<std::pair<int32_t, int32_t>> input_pairs = in_iter.second;
  3306. OpNodeContext &src_context = src_iter->second;
  3307. src_context.output_map[node_name].assign(input_pairs.begin(), input_pairs.end());
  3308. }
  3309. }
  3310. }
  3311. }
  3312. }
  3313. void TensorFlowModelParser::UpdateInnerOutputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
  3314. const std::vector<std::string> &inner_nodes_name,
  3315. std::set<string> &fusion_output_nodes) {
  3316. std::map<string, std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>>> remap_data_output;
  3317. std::map<string, std::vector<string>> remap_ctrl_output;
  3318. GetFusionOutputInfo(fusion_op_name, fusion_context, remap_data_output, remap_ctrl_output, fusion_output_nodes);
  3319. for (const auto &node_name : inner_nodes_name) {
  3320. auto context_iter = op_node_context_map_.find(node_name);
  3321. if (context_iter != op_node_context_map_.end()) {
  3322. OpNodeContext &op_node_context = context_iter->second;
  3323. // update output map of inner node
  3324. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> tmp_output_map;
  3325. for (auto iter = op_node_context.output_map.begin(); iter != op_node_context.output_map.end();) {
  3326. string dst_name = iter->first;
  3327. if (dst_name == ge::kOutputToFusionScope) {
  3328. std::vector<std::pair<int32_t, int32_t>> &output_idx = iter->second;
  3329. for (const auto &out_pair : output_idx) {
  3330. if (out_pair.second != kControlSlot) {
  3331. auto data_outputs = remap_data_output[fusion_op_name + std::to_string(out_pair.second)];
  3332. for (const auto &data : data_outputs) {
  3333. tmp_output_map[data.first].emplace_back(std::make_pair(out_pair.first, data.second.second));
  3334. GELOGI("Update inner output, dst:%s, idx:%u->%u.", data.first.c_str(), out_pair.first,
  3335. data.second.second);
  3336. }
  3337. }
  3338. }
  3339. auto ctrl = remap_ctrl_output[fusion_op_name + std::to_string(kControlSlot)];
  3340. for (const auto &ctrl_in : ctrl) {
  3341. tmp_output_map[ctrl_in].emplace_back(std::make_pair(kControlSlot, kControlSlot));
  3342. }
  3343. iter = op_node_context.output_map.erase(iter);
  3344. } else {
  3345. ++iter;
  3346. }
  3347. }
  3348. op_node_context.output_map.insert(tmp_output_map.cbegin(), tmp_output_map.cend());
  3349. // update input map of pre node
  3350. for (const auto &out_iter : op_node_context.output_map) {
  3351. auto dst_iter = op_node_context_map_.find(out_iter.first);
  3352. if (dst_iter != op_node_context_map_.end()) {
  3353. std::vector<std::pair<int32_t, int32_t>> output_pairs = out_iter.second;
  3354. OpNodeContext &dst_context = dst_iter->second;
  3355. dst_context.input_map[node_name].assign(output_pairs.begin(), output_pairs.end());
  3356. }
  3357. }
  3358. }
  3359. }
  3360. }
  3361. Status TensorFlowModelParser::UpdateInnerNodeContext(const string &fusion_op_name,
  3362. const std::vector<std::string> &inner_nodes_name) {
  3363. auto fusion_iter = op_node_context_map_.find(fusion_op_name);
  3364. if (fusion_iter == op_node_context_map_.end()) {
  3365. REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", fusion_op_name.c_str());
  3366. GELOGE(INTERNAL_ERROR, "Can't find context for fusion node %s.", fusion_op_name.c_str());
  3367. return INTERNAL_ERROR;
  3368. }
  3369. OpNodeContext &fusion_context = fusion_iter->second;
  3370. std::set<string> fusion_input_nodes;
  3371. std::set<string> fusion_output_nodes;
  3372. UpdateInnerInputMap(fusion_op_name, fusion_context, inner_nodes_name, fusion_input_nodes);
  3373. UpdateInnerOutputMap(fusion_op_name, fusion_context, inner_nodes_name, fusion_output_nodes);
  3374. for (const auto &in_name : fusion_input_nodes) {
  3375. auto fusion_in = op_node_context_map_.find(in_name);
  3376. if (fusion_in != op_node_context_map_.end()) {
  3377. OpNodeContext &fusion_in_context = fusion_in->second;
  3378. fusion_in_context.output_map.erase(fusion_op_name);
  3379. }
  3380. }
  3381. for (const auto &out_name : fusion_output_nodes) {
  3382. auto fusion_out = op_node_context_map_.find(out_name);
  3383. if (fusion_out != op_node_context_map_.end()) {
  3384. OpNodeContext &fusion_out_context = fusion_out->second;
  3385. fusion_out_context.input_map.erase(fusion_op_name);
  3386. }
  3387. }
  3388. op_node_context_map_.erase(fusion_op_name);
  3389. return SUCCESS;
  3390. }
  3391. Status TensorFlowModelParser::AddFusionInnerNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph,
  3392. const string &fusion_op_name, vector<string> &node_name_list) {
  3393. auto &impl_scope_graph = scope_graph->impl_;
  3394. GE_CHECK_NOTNULL(impl_scope_graph);
  3395. ge::FusionScopesResult *fusion_result = impl_scope_graph->GetFusionScopesResults(fusion_op_name);
  3396. GE_CHECK_NOTNULL(fusion_result);
  3397. auto &impl_fusion_rlt = fusion_result->impl_;
  3398. GE_CHECK_NOTNULL(impl_fusion_rlt);
  3399. ge::FusionInnerNodesInfo inner_nodes_info = impl_fusion_rlt->GetInnerNodesInfo();
  3400. vector<string> inner_nodes_name;
  3401. for (const auto &info : inner_nodes_info) {
  3402. string node_name;
  3403. string type;
  3404. std::vector<std::pair<std::string, int32_t>> inputs;
  3405. std::vector<std::pair<std::string, int32_t>> outputs;
  3406. const ge::Operator *op = nullptr;
  3407. std::tie(node_name, type, inputs, outputs, op) = info;
  3408. NodeDef *node_def = new (std::nothrow) NodeDef();
  3409. GE_CHECK_NOTNULL(node_def);
  3410. node_def->set_name(node_name);
  3411. node_def->set_op(type);
  3412. nodedef_map_[node_name] = node_def;
  3413. fusion_nodedef_list.push_back(node_def);
  3414. for (const auto &in : inputs) {
  3415. // The input value is not used in the subsequent process. The value is added only for placeholders.
  3416. node_def->add_input(in.first);
  3417. }
  3418. domi::tensorflow::AttrValue attr_value;
  3419. attr_value.set_b(true);
  3420. ge::TensorFlowUtil::AddNodeAttr(kAttrNameIsScopeInnerNode, attr_value, node_def);
  3421. OpNodeContext &op_node_context = op_node_context_map_[node_name];
  3422. Status ret = SetOriginNodeContext(node_def, op_node_context, inputs, outputs);
  3423. if (ret != SUCCESS) {
  3424. GELOGE(ret, "Failed to add context and attrs, node:%s.", node_name.c_str());
  3425. return ret;
  3426. }
  3427. scope_inner_node_map_.insert({node_name, op});
  3428. node_name_list.emplace_back(node_name);
  3429. inner_nodes_name.emplace_back(node_name);
  3430. GELOGI("Add fusion inner node def, name:%s, type:%s.", node_name.c_str(), type.c_str());
  3431. }
  3432. Status ret = UpdateInnerNodeContext(fusion_op_name, inner_nodes_name);
  3433. if (ret != SUCCESS) {
  3434. GELOGE(ret, "Failed to update inner node context, fusion_op_name:%s.", fusion_op_name.c_str());
  3435. return ret;
  3436. }
  3437. return SUCCESS;
  3438. }
  3439. Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph,
  3440. vector<string> &node_name_list) {
  3441. vector<string> node_name_list_new;
  3442. size_t op_node_list_size = node_name_list.size();
  3443. DumpAllNodeContext("BeforeAddFusionNodeDef");
  3444. for (size_t i = 0; i < op_node_list_size; ++i) {
  3445. const string op_node_name = node_name_list[i];
  3446. std::map<string, vector<const NodeDef *>>::const_iterator iter = fusion_op_nodedef_map_.find(op_node_name);
  3447. if (iter != fusion_op_nodedef_map_.end()) {
  3448. vector<string> fusion_op_info = fusion_op_type_map_[op_node_name];
  3449. if (fusion_op_info[0] != ge::kScopeToMultiNodes) {
  3450. NodeDef *node_def = new (std::nothrow) NodeDef();
  3451. GE_CHECK_NOTNULL(node_def);
  3452. node_def->set_name(op_node_name);
  3453. node_def->set_op(fusion_op_info[0]);
  3454. nodedef_map_[op_node_name] = node_def;
  3455. fusion_nodedef_list.push_back(node_def);
  3456. OpNodeContext &node_context = op_node_context_map_[node_def->name()];
  3457. for (const auto &input : node_context.input_map) {
  3458. // The input value is not used in the subsequent process. The value is added only for placeholders.
  3459. node_def->add_input(input.first);
  3460. }
  3461. node_name_list_new.emplace_back(op_node_name);
  3462. GELOGI("Add Fusion node def, name:%s, type:%s.", node_def->name().c_str(), node_def->op().c_str());
  3463. } else {
  3464. Status ret = AddFusionInnerNodeDef(scope_graph, op_node_name, node_name_list_new);
  3465. if (ret != SUCCESS) {
  3466. REPORT_INNER_ERROR("E19999",
  3467. "Failed to add fusion inner nodes for fusion op:%s, "
  3468. "please check FusionScopesResult set in scope fusion pass",
  3469. op_node_name.c_str());
  3470. GELOGE(ret, "Failed to add fusion inner node, fusion_op_name:%s.", op_node_name.c_str());
  3471. return ret;
  3472. }
  3473. GELOGI("Add fusion inner nodes successfully, fusion name:%s.", op_node_name.c_str());
  3474. op_node_context_map_.erase(op_node_name);
  3475. }
  3476. } else {
  3477. node_name_list_new.emplace_back(op_node_name);
  3478. }
  3479. }
  3480. node_name_list.clear();
  3481. node_name_list.assign(node_name_list_new.begin(), node_name_list_new.end());
  3482. DumpAllNodeContext("AfterAddFusionNodeDef");
  3483. return SUCCESS;
  3484. }
  3485. Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph,
  3486. std::mutex *const graph_mutex,
  3487. const domi::tensorflow::NodeDef *node_def) {
  3488. // This is an internal function. The pointer input parameter is not empty when this function is invoked.
  3489. string node_name = node_def->name();
  3490. string node_op = node_def->op();
  3491. auto iter = parser->scope_inner_node_map_.find(node_name);
  3492. if (iter == parser->scope_inner_node_map_.end()) {
  3493. REPORT_INNER_ERROR("E19999", "Node:%s can't find in scope_inner_node_map_, check invalid", node_name.c_str());
  3494. GELOGE(PARAM_INVALID, "Failed to find scope inner node:%s, type:%s.", node_name.c_str(), node_op.c_str());
  3495. return PARAM_INVALID;
  3496. }
  3497. const ge::Operator *op = iter->second;
  3498. ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op);
  3499. GE_CHECK_NOTNULL(op_desc);
  3500. ge::NodePtr node;
  3501. {
  3502. std::lock_guard<std::mutex> lock(*graph_mutex);
  3503. node = graph->AddNode(op_desc);
  3504. }
  3505. if (node == nullptr) {
  3506. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
  3507. op_desc->GetType().c_str(), graph->GetName().c_str());
  3508. GELOGE(INTERNAL_ERROR, "Failed to Add scope inner node:%s, type:%s.", op_desc->GetName().c_str(),
  3509. op_desc->GetType().c_str());
  3510. return INTERNAL_ERROR;
  3511. }
  3512. {
  3513. std::lock_guard<std::mutex> lock(parser->nodeMapMutex_);
  3514. parser->node_map_[node_name] = node;
  3515. }
  3516. GELOGI("Add scope inner node successfully, node name:%s, type:%s.", op_desc->GetName().c_str(),
  3517. op_desc->GetType().c_str());
  3518. return SUCCESS;
  3519. }
  3520. void TensorFlowModelParser::DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase) {
  3521. GELOGD("phase:%s === Begin to dump context for node:%s ===", phase.c_str(), node_name.c_str());
  3522. for (const auto &input : ctx.input_map) {
  3523. for (const auto &input_idx : input.second) {
  3524. GELOGD(" Input info: %s:%d --> in_idx %d.", input.first.c_str(), input_idx.first, input_idx.second);
  3525. }
  3526. }
  3527. for (const auto &output : ctx.output_map) {
  3528. for (const auto &output_idx : output.second) {
  3529. GELOGD(" Output info: out_idx %d --> %s:%d.", output_idx.first, output.first.c_str(), output_idx.second);
  3530. }
  3531. }
  3532. GELOGD("phase:%s === End to dump context for node:%s ===", phase.c_str(), node_name.c_str());
  3533. }
  3534. void TensorFlowModelParser::DumpAllNodeContext(const string &phase) const {
  3535. if (!IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) {
  3536. return;
  3537. }
  3538. for (const auto &iter : op_node_context_map_) {
  3539. DumpNodeContext(iter.first, iter.second, phase);
  3540. }
  3541. }
  3542. Status TensorFlowModelParser::CheckAndUpdateInputDesc(const ge::ComputeGraphPtr &compute_graph) {
  3543. GE_CHECK_NOTNULL(compute_graph);
  3544. for (auto &node : compute_graph->GetDirectNode()) {
  3545. auto op_desc = node->GetOpDesc();
  3546. GE_CHECK_NOTNULL(op_desc);
  3547. for (auto &in_anchor : node->GetAllInDataAnchors()) {
  3548. if (!(op_desc->IsOptionalInput(static_cast<uint32_t>(in_anchor->GetIdx())))) {
  3549. continue;
  3550. }
  3551. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  3552. auto in_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()));
  3553. if ((peer_out_anchor != nullptr) && (in_desc == nullptr)) {
  3554. // The input is connected to the peer output but TensorDesc is invalid, update TensorDesc to valid.
  3555. ge::GeTensorDesc tensor_desc;
  3556. auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc);
  3557. if (ret != ge::GRAPH_SUCCESS) {
  3558. REPORT_CALL_ERROR("E19999", "Update index:%d of input desc in op:%s(%s) failed", in_anchor->GetIdx(),
  3559. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  3560. GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
  3561. return ret;
  3562. }
  3563. GELOGI("Update input desc to valid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
  3564. } else if ((peer_out_anchor == nullptr) && (in_desc != nullptr)) {
  3565. // The input is not connected to the peer output but TensorDesc is valid, update TensorDesc to invalid.
  3566. ge::GeTensorDesc tensor_desc(ge::GeShape(), FORMAT_RESERVED, DT_UNDEFINED);
  3567. auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc);
  3568. if (ret != ge::GRAPH_SUCCESS) {
  3569. REPORT_CALL_ERROR("E19999", "Update index:%d of input desc in op:%s(%s) failed", in_anchor->GetIdx(),
  3570. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  3571. GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
  3572. return ret;
  3573. }
  3574. GELOGI("Update input desc to invalid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
  3575. }
  3576. }
  3577. }
  3578. return SUCCESS;
  3579. }
  3580. Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes) {
  3581. auto &user_specified_nodes = ge::GetParserContext().user_out_nodes;
  3582. if (!user_specified_nodes.empty()) {
  3583. for (auto &output_node_info : user_specified_nodes) {
  3584. ParserUtils::UpdateOutputNodeInfo(final_output_nodes, output_node_info);
  3585. }
  3586. }
  3587. return SUCCESS;
  3588. }
  3589. Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph) {
  3590. GE_CHECK_NOTNULL(root_graph);
  3591. for (const NodePtr &node : root_graph->GetAllNodes()) {
  3592. if (node == nullptr || node->GetOpDesc() == nullptr) {
  3593. continue;
  3594. }
  3595. std::string model_data;
  3596. if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) {
  3597. ge::Model model;
  3598. auto load_ret = ge::Model::Load(ge::PtrToPtr<char_t, const uint8_t>(model_data.data()), model_data.size(), model);
  3599. if (load_ret != GRAPH_SUCCESS) {
  3600. GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str());
  3601. REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str());
  3602. return INTERNAL_ERROR;
  3603. }
  3604. Graph graph = model.GetGraph();
  3605. GELOGD("Get subgraph[%s] from model[%s].", ParserUtils::GetGraphName(graph).c_str(), node->GetName().c_str());
  3606. Status ret = MappingAndAddSubGraph(node, graph, root_graph);
  3607. if (ret != SUCCESS) {
  3608. GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]Node:%s.", node->GetName().c_str());
  3609. REPORT_CALL_ERROR("E19999", "Failed to map and add sub graph, node:%s.", node->GetName().c_str());
  3610. return INTERNAL_ERROR;
  3611. }
  3612. }
  3613. }
  3614. return SUCCESS;
  3615. }
  3616. } // namespace ge
  3617. namespace domi {
  // Judging by the macro names, these register ge::TensorFlowModelParser and ge::TensorFlowWeightsParser as the
  // model/weights parser creators for the TENSORFLOW framework. The macro definitions are not visible here —
  // NOTE(review): confirm they expand to static registrar objects before relying on initialization order.
  3618. REGISTER_MODEL_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowModelParser);
  3619. REGISTER_WEIGHTS_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowWeightsParser);
  3620. } // namespace domi