You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensorflow_parser.cc 187 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
4 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
3 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
3 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135
  1. /*
  2. * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parser/tensorflow/tensorflow_parser.h"
  17. #include <algorithm>
  18. #include <iostream>
  19. #include "ge/ge_api_types.h"
  20. #include "parser/common/convert/pb2json.h"
  21. #include "parser/common/acl_graph_parser_util.h"
  22. #include "common/util/error_manager/error_manager.h"
  23. #include "external/graph/operator_factory.h"
  24. #include "external/parser/tensorflow_parser.h"
  25. #include "external/register/scope/scope_fusion_pass_register.h"
  26. #include "framework/common/debug/ge_log.h"
  27. #include "framework/omg/parser/parser_api.h"
  28. #include "framework/omg/parser/parser_inner_ctx.h"
  29. #include "graph/debug/ge_attr_define.h"
  30. #include "graph/utils/graph_utils.h"
  31. #include "graph/utils/node_utils.h"
  32. #include "graph/utils/type_utils.h"
  33. #include "iterator_fusion_pass.h"
  34. #include "omg/parser/op_parser.h"
  35. #include "omg/parser/parser_factory.h"
  36. #include "parser/common/acl_graph_parser_util.h"
  37. #include "parser/common/model_saver.h"
  38. #include "parser/common/op_map.h"
  39. #include "parser/common/op_parser_factory.h"
  40. #include "parser/common/parser_fp16_t.h"
  41. #include "parser/common/pass_manager.h"
  42. #include "parser/common/prototype_pass_manager.h"
  43. #include "parser/common/thread_pool.h"
  44. #include "parser/common/parser_utils.h"
  45. #include "parser/common/util.h"
  46. #include "parser/tensorflow/tensorflow_custom_parser_adapter.h"
  47. #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h"
  48. #include "parser/tensorflow/tensorflow_fusion_op_parser.h"
  49. #include "parser/tensorflow/tensorflow_op_parser.h"
  50. #include "parser/tensorflow/tensorflow_util.h"
  51. #include "register/op_registry.h"
  52. #include "register/register_utils.h"
  53. #include "register/scope/scope_pass_registry_impl.h"
  54. #include "parser/common/auto_mapping_subgraph_io_index_func.h"
  55. #include "graph/def_types.h"
  56. using ge::OpParserFactory;
  57. using ge::Pb2Json;
  58. using ge::PreChecker;
  59. using ge::TENSORFLOW_ATTR_DATA_FORMAT;
  60. using ge::TENSORFLOW_ATTR_DTYPE;
  61. using ge::TENSORFLOW_ATTR_SHAPE;
  62. using ge::TENSORFLOW_ATTR_T;
  63. using ge::TENSORFLOW_ATTR_TYPE_STRING;
  64. using ge::TENSORFLOW_ATTR_TYPE_TENSOR;
  65. using ge::TENSORFLOW_ATTR_TYPE_TYPE;
  66. using ge::TENSORFLOW_ATTR_VALUE;
  67. using ge::TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG;
  68. using ge::TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG;
  69. using ge::tensorflow_op_map;
  70. using ge::tensorflow_train_op_map;
  71. using ge::TENSORFLOWF_NODE_OP_CONST;
  72. using ge::TENSORFLOWF_NODE_OP_IDENTITY;
  73. using ge::TENSORFLOWF_NODE_OP_MERGE;
  74. using ge::TENSORFLOWF_NODE_OP_PLACEHOLDER;
  75. using ge::TENSORFLOWF_NODE_OP_SWITCH;
  76. using ge::TENSORFLOWF_NODE_OP_TRANSPOSE;
  77. using ge::TENSORFLOWF_TENSOR_NCHW;
  78. using ge::TENSORFLOWF_TENSOR_NHWC;
  79. using ge::TensorFlowFusionCustomParserAdapter;
  80. using ge::TensorFlowFusionOpParser;
  81. using ge::TensorFlowOpParser;
  82. using ge::ThreadPool;
  83. using ge::parser::fp16_t;
  84. using ge::parser::ModelSaver;
  85. namespace ge {
  86. graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {
  87. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  88. GE_CHECK_NOTNULL(model_file);
  89. GetParserContext().type = domi::TENSORFLOW;
  90. std::map<string, string> options;
  91. options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW)));
  92. // load custom plugin so and proto
  93. AclGraphParserUtil acl_graph_parse_util;
  94. if (acl_graph_parse_util.AclParserInitialize(options) != domi::SUCCESS) {
  95. GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
  96. return GRAPH_FAILED;
  97. }
  98. // Create an empty computegraph
  99. ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph" +
  100. std::to_string(ge::parser::GetCurrentTimestamp()));
  101. if (compute_graph == nullptr) {
  102. REPORT_CALL_ERROR("E19999", "New ComputeGraph failed");
  103. GELOGE(FAILED, "Create ComputeGraph fail.");
  104. return FAILED;
  105. }
  106. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  107. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW);
  108. if (model_parser == nullptr) {
  109. REPORT_CALL_ERROR("E19999", "No Model Parser for tensorflow, check invalid");
  110. GELOGE(GRAPH_FAILED, "No Model Parser for tensorflow, check invalid");
  111. return FAILED;
  112. }
  113. // parse tensorflow model_file to GE graph
  114. ge::graphStatus ret = model_parser->Parse(model_file, graph);
  115. if (ret != ge::SUCCESS) {
  116. GELOGE(ret, "Parser graph %s failed.", ParserUtils::GetGraphName(graph).c_str());
  117. return ge::FAILED;
  118. }
  119. std::map<AscendString, AscendString> parser_params;
  120. if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
  121. GELOGE(ret, "Set graph %s default output node failed.", ParserUtils::GetGraphName(graph).c_str());
  122. return ge::FAILED;
  123. }
  124. GELOGI("Parser graph %s success.", ParserUtils::GetGraphName(graph).c_str());
  125. return ge::SUCCESS;
  126. }
  127. graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<AscendString, AscendString> &parser_params,
  128. ge::Graph &graph) {
  129. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  130. GE_CHECK_NOTNULL(model_file);
  131. GetParserContext().type = domi::TENSORFLOW;
  132. std::map<string, string> options;
  133. options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW)));
  134. // load custom plugin so and proto
  135. AclGraphParserUtil acl_graph_parse_util;
  136. domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
  137. if (status != domi::SUCCESS) {
  138. GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
  139. return GRAPH_FAILED;
  140. }
  141. string output_name;
  142. if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
  143. GELOGE(ge::FAILED, "Parser params before graph failed.");
  144. return ge::FAILED;
  145. }
  146. // Create an empty computegraph
  147. string graph_name = output_name.empty() ? ("tmpGraph" +
  148. std::to_string(ge::parser::GetCurrentTimestamp())) : output_name;
  149. ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name);
  150. if (compute_graph == nullptr) {
  151. REPORT_CALL_ERROR("E19999", "New ComputeGraph failed");
  152. GELOGE(FAILED, "Create ComputeGraph fail.");
  153. return FAILED;
  154. }
  155. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  156. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW);
  157. if (model_parser == nullptr) {
  158. REPORT_CALL_ERROR("E19999", "No Model Parser for tensorflow, check invalid");
  159. GELOGE(GRAPH_FAILED, "No Model Parser for tensorflow, check invalid");
  160. return FAILED;
  161. }
  162. // parse tensorflow model_file to GE graph
  163. ge::graphStatus ret = model_parser->Parse(model_file, graph);
  164. if (ret != ge::SUCCESS) {
  165. GELOGE(ret, "Parser graph %s failed.", ParserUtils::GetGraphName(graph).c_str());
  166. return ge::FAILED;
  167. }
  168. if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
  169. GELOGE(ge::FAILED, "Parser params after graph failed.");
  170. return ge::FAILED;
  171. }
  172. if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
  173. GELOGE(ge::FAILED, "Set graph %s default output node failed.", ParserUtils::GetGraphName(graph).c_str());
  174. return ge::FAILED;
  175. }
  176. GELOGI("AclgrphParse graph %s success.", ParserUtils::GetGraphName(graph).c_str());
  177. return ge::SUCCESS;
  178. }
  179. void AddDumpOriginName(const ge::NodePtr parent_node, const std::string& subgraph_name, ge::ComputeGraphPtr graph) {
  180. if (parent_node == nullptr) {
  181. return; // Root graph no need set dump origin name as parser always keep the origin node name
  182. }
  183. std::vector<std::string> original_names;
  184. (void)ge::AttrUtils::GetListStr(parent_node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
  185. if (original_names.empty()) {
  186. original_names.emplace_back(parent_node->GetName());
  187. }
  188. // for fusion node also used original_names[0]
  189. std::string prefix = original_names[0].append("/").append(subgraph_name).append("/");
  190. for (const ge::NodePtr &node : graph->GetDirectNode()) {
  191. original_names[0] = prefix + node->GetName();
  192. if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) {
  193. GELOGW("Set dump origin name to %s fail.", node->GetOpDesc()->GetName().c_str());
  194. }
  195. GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str());
  196. }
  197. }
  198. } // namespace ge
  199. namespace ge {
  200. namespace {
// File-local tuning knobs and lookup tables for the TF parser.
// NOTE(review): usage sites are not all visible in this chunk — the per-constant
// notes below marked "presumably" should be confirmed against their call sites.
const int kTransposeInputIdx = 0;    // Input slot of Transpose carrying the data tensor.
const uint32_t kThreadNum = 16;      // presumably the worker count for parallel parsing — confirm at use site
const size_t kInputNumUint = 2;      // Expected input count (unsigned form).
const int kInputNumInt = 2;          // Expected input count (signed form, same value as kInputNumUint).
const int32_t kControlSlot = -1;     // presumably TF's convention of slot -1 for control edges — confirm
const size_t kSoftmaxMultiple = 2;   // presumably a size multiplier used by softmax handling — confirm
const set<string> kTfBlackFields = {"tensor_content"};  // Proto fields to skip (large raw payloads).
// Op types exempt from the input-size check.
const std::vector<std::string> kSkipCheckoutInputSizeNodes = {ge::parser::DATA, ge::parser::VARIABLE,
    ge::parser::FRAMEWORKOP, ge::parser::LAYERNORM};
// Op types whose Operator is constructed directly rather than via IR factory.
const std::vector<std::string> kMakeOperatorNotByIr = {ge::parser::ARG, ge::parser::VARIABLE, ge::parser::VARHANDLEOP,
    ge::parser::FRAMEWORKOP, ge::parser::DATA};
const char *const kDpop = "DPOP";
const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt";  // Dump target for the TF function library.
const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node";
const char *const kExternalModel = "_external_model";
// Work item for the (sub)graph parse queue: one graph to parse plus where its
// result attaches in the parent graph. Field order is load-bearing: instances
// are built with brace aggregate initialization elsewhere in this file.
struct ParseArg {
  const google::protobuf::Message *proto;  // Definition to parse; nullptr when only function_name identifies it.
  std::string function_name;               // TF function (subgraph instance) name to resolve.
  ge::NodePtr parent_node;                 // Owning node in the parent graph; nullptr for the root graph.
  std::string subgraph_name;               // Subgraph slot name on the parent node.
  ge::ComputeGraphPtr graph;               // Destination compute graph for the parsed nodes.
};
// Walk every direct node of parent_graph; for each subgraph name/index
// registered on a node's OpDesc, create an empty ComputeGraph, bind it to the
// node, and enqueue a ParseArg so the caller parses subgraph bodies
// iteratively (queue-driven, not recursive).
Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque<ParseArg> &args) {
GELOGI("Gen subgraph parse tasks start");
for (auto &node : parent_graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (const auto &subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) {
auto i = subgraph_name_to_index.second;
auto subgraph_iname = op_desc->GetSubgraphInstanceName(i);
// An unset instance name is tolerated: warn and move on to the next index.
if (subgraph_iname.empty()) {
GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str());
continue;
}
// A function may be referenced multiple times in TF, change the graph name to ensure it is unique in GE
auto unique_name = node->GetName() + std::to_string(i) + subgraph_iname;
auto subgraph = ge::parser::MakeShared<ge::ComputeGraph>(unique_name);
if (subgraph == nullptr) {
REPORT_CALL_ERROR("E19999", "New ComputeGraph failed when create subgraph:%s", subgraph_iname.c_str());
GELOGE(OUT_OF_MEMORY, "Failed to alloc subgraph %s", subgraph_iname.c_str());
return OUT_OF_MEMORY;
}
// Attach the (still empty) subgraph to the node before parsing its body.
auto ret = ge::NodeUtils::SetSubgraph(*node, i, subgraph);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Set subgraph:%s to node:%s(%s) failed, index:%u", subgraph_iname.c_str(),
node->GetName().c_str(), node->GetType().c_str(), i);
GELOGE(ret, "Set subgraph %s to node %s failed, index %u", subgraph_iname.c_str(), node->GetName().c_str(), i);
return ret;
}
GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s",
node->GetName().c_str(), i, subgraph_iname.c_str());
// proto is nullptr: the body will be looked up by function name (ParseArg).
args.push_back({nullptr, subgraph_iname, node, subgraph_name_to_index.first, subgraph});
}
}
GELOGI("Gen subgraph parse tasks end");
return SUCCESS;
}
// After a subgraph has been parsed, run the parent op's registered subgraph
// post-parse function (v1, falling back to v2) so the op adapter can remap the
// subgraph's inputs/outputs. Returns SUCCESS when there is no parent node or
// when no post function is registered for the op type.
Status PostOpProcessForSubgraph(const ParseArg &arg) {
AddDumpOriginName(arg.parent_node, arg.subgraph_name, arg.graph);
// A null parent means arg.graph is the root graph: nothing to post-process.
if (arg.parent_node == nullptr) {
return SUCCESS;
}
std::string op_type = arg.parent_node->GetType();
std::string op_name = arg.parent_node->GetName();
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type);
if (post_func == nullptr) {
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str());
// Fall back to the v2 registration; absence of both is not an error.
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS ||
parse_func_v2 == nullptr) {
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str());
return SUCCESS;
}
}
GELOGD("Post process for subgraph %s node %s type %s subgraph name %s", arg.function_name.c_str(),
arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), arg.subgraph_name.c_str());
// refresh node_name in subgraph
for (const ge::NodePtr &node : arg.graph->GetDirectNode()) {
// Variables keep their original names so variable sharing by name still works.
if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) {
continue;
}
node->GetOpDesc()->SetName(node->GetOwnerComputeGraph()->GetName() + "/" + node->GetName());
}
auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(arg.graph);
Status ret = FAILED;
// Exactly one of the two is non-null here (post_func checked above).
if (post_func != nullptr) {
ret = post_func(arg.subgraph_name, graph);
} else if (parse_func_v2 != nullptr) {
ret = parse_func_v2(arg.subgraph_name.c_str(), graph);
}
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Call ParseSubgraphPostFunc:%s failed, subgraph:%s, node:%s(%s), ret:0x%X",
arg.function_name.c_str(), arg.subgraph_name.c_str(), arg.parent_node->GetName().c_str(),
arg.parent_node->GetType().c_str(), ret);
GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s subgraph name %s", arg.function_name.c_str(),
arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), arg.subgraph_name.c_str());
return FAILED;
}
return SUCCESS;
}
// Map a parsed function graph onto `node`: propagate the parent's data format
// into the subgraph, auto-map subgraph Data/NetOutput indexes 1:1, attach the
// graph as subgraph "f" of the node, then lift any nested subgraphs up into
// root_graph so all subgraphs end up registered on the root.
Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, const ComputeGraphPtr &root_graph) {
// Inner function, data format need be set by parant node
GE_CHK_STATUS_RET(AutoMappingSubgraphDataFormat(node, graph),
"[Call][AutoMappingSubgraphDataFormat] failed, node:%s, "
"root graph:%s, graph:%s",
node->GetName().c_str(), root_graph->GetName().c_str(),
ParserUtils::GetGraphName(graph).c_str());
// Inner function, input params have been checked by caller
// Identity mappings: subgraph input i <-> parent input i, same for outputs.
Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(
graph,
[](int in, int &out) -> Status {
out = in;
return SUCCESS;
},
[](int in, int &out) -> Status {
out = in;
return SUCCESS;
});
if (status != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]node:%s, sub graph name:%s.", node->GetName().c_str(),
ParserUtils::GetGraphName(graph).c_str());
REPORT_CALL_ERROR("E19999", "Failed to map sub graph input and output, node:%s, sub graph name:%s.",
node->GetName().c_str(), ParserUtils::GetGraphName(graph).c_str());
return INTERNAL_ERROR;
}
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
// Inner function, GetOpDesc has been checked by caller
(void)node->GetOpDesc()->AddSubgraphName("f");
auto ret = NodeUtils::SetSubgraph(*node, 0, compute_graph);
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Set][Subgraph]Node:%s, sub graph name:%s.", node->GetName().c_str(),
compute_graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to set sub graph, node: %s, sub graph name: %s.", node->GetName().c_str(),
compute_graph->GetName().c_str());
return INTERNAL_ERROR;
}
// Re-home nested subgraphs onto the root graph.
// NOTE(review): RemoveSubgraph is called while ranging over GetAllSubgraphs —
// assumes GetAllSubgraphs returns a snapshot/copy; confirm in graph utils.
for (const auto &sub_graph : compute_graph->GetAllSubgraphs()) {
ret = root_graph->AddSubgraph(sub_graph);
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Add][Subgraph]Node:%s, sub graph name:%s, sub sub graph name:%s.",
node->GetName().c_str(), compute_graph->GetName().c_str(), sub_graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to add sub graph to root graph, node:%s, sub graph name:%s.",
node->GetName().c_str(), sub_graph->GetName().c_str());
return INTERNAL_ERROR;
}
compute_graph->RemoveSubgraph(sub_graph);
GELOGD("Add subgraph[%s] to root graph[%s].", sub_graph->GetName().c_str(), root_graph->GetName().c_str());
}
return SUCCESS;
}
  352. } // namespace
  353. /*
  354. * @ingroup domi_omg
  355. * @brief Trans common decorate function to PartitionedCall.
  356. * @param [in] node_def: Node of common function.
  357. * @param [out] op: result of PartitionedCall OpDesc.
  358. * @return 0: SUCCESS / Others: FAILED
  359. */
  360. Status TensorFlowModelParser::DefunToPartitionedCall(const domi::tensorflow::NodeDef *node_def,
  361. ge::OpDescPtr &op) const {
  362. const string op_name = node_def->name();
  363. domi::tensorflow::AttrValue attr_call_inference;
  364. if (!ge::TensorFlowUtil::FindAttrValue(node_def, "_disable_call_shape_inference", attr_call_inference)) {
  365. ErrorManager::GetInstance().ATCReportErrMessage(
  366. "E19014", {"opname", "value", "reason"},
  367. {node_def->name(), "attr [_disable_call_shape_inference]",
  368. "may has no ir definition, if it is not a common decorate function operator"});
  369. GELOGE(FAILED,
  370. "Op %s has no ir definition, or has no attr [_disable_call_shape_inference] "
  371. "if it is a common decorate function operator.",
  372. op_name.c_str());
  373. return FAILED;
  374. }
  375. op = ge::parser::MakeShared<ge::OpDesc>(op_name, ge::parser::PARTITIONEDCALL);
  376. GE_CHECK_NOTNULL(op);
  377. size_t input_tensor_num = 0;
  378. size_t output_tensor_num = 0;
  379. GetInputOutputTensorNum(op, input_tensor_num, output_tensor_num);
  380. for (size_t i = 0; i < input_tensor_num; ++i) {
  381. ge::GeTensorDesc input_tensor;
  382. if (op->AddInputDesc(input_tensor) != ge::GRAPH_SUCCESS) {
  383. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", op->GetName().c_str(), op->GetType().c_str());
  384. GELOGE(FAILED, "op [%s] type[%s] add input(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i);
  385. return FAILED;
  386. }
  387. }
  388. for (size_t i = 0; i < output_tensor_num; ++i) {
  389. ge::GeTensorDesc output_tensor;
  390. if (op->AddOutputDesc(output_tensor) != ge::GRAPH_SUCCESS) {
  391. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", op->GetName().c_str(), op->GetType().c_str());
  392. GELOGE(FAILED, "op [%s] type[%s] add output(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i);
  393. return FAILED;
  394. }
  395. }
  396. GELOGI("After AddTensorDescToOpDesc op[%s]: type[%s] have input size: %zu, output size: %zu, disable inference: %d",
  397. op_name.c_str(), op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize(), attr_call_inference.b());
  398. (void)op->AddSubgraphName("f");
  399. (void)op->SetSubgraphInstanceName(0, op_name);
  400. return SUCCESS;
  401. }
// Build an OpDesc for node_def of the given op_type. Resolution order:
//  1) types in kMakeOperatorNotByIr are constructed directly;
//  2) a node whose name equals its type is treated as a decorate function and
//     converted to PartitionedCall (DefunToPartitionedCall);
//  3) otherwise the IR OperatorFactory must know the type, and the factory
//     operator's OpDesc is used with tensor descs filled from the NodeDef.
Status TensorFlowModelParser::TransNodeToOpDesc(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op,
const string &op_type) const {
GE_CHECK_NOTNULL(node_def);
string node_name = node_def->name();
ge::Operator op_factory = ge::OperatorFactory::CreateOperator(node_name.c_str(), op_type.c_str());
// A name mismatch means the factory failed to create the operator (no IR).
// DATA is always built directly even though it is IR-registered.
if (ParserUtils::GetOperatorName(op_factory) != node_name || op_type == ge::parser::DATA) {
if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) {
op = ge::parser::MakeShared<ge::OpDesc>(node_name, op_type);
GE_CHECK_NOTNULL(op);
} else if (node_name == op_type) {
// Trans @tensorflow.python.framework.Defun(...) to PartitionedCall.
GE_RETURN_IF_ERROR(DefunToPartitionedCall(node_def, op));
GE_CHECK_NOTNULL(op);
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10501", {"opname", "optype"}, {node_name, op_type});
GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str());
return FAILED;
}
} else {
op = ge::OpDescUtils::GetOpDescFromOperator(op_factory);
GE_CHECK_NOTNULL(op);
GELOGI("After GetOpDescFromOperator op[%s]: type[%s] has input size: %zu, output size: %zu", op->GetName().c_str(),
op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
GE_RETURN_IF_ERROR(AddTensorDescToOpDesc(op, node_def));
GELOGI("After AddTensorDescToOpDesc op[%s]: type[%s] has input size: %zu, output size: %zu", op->GetName().c_str(),
op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
}
// Detach the factory operator from its OpDesc to avoid dangling references.
op_factory.BreakConnect();
return SUCCESS;
}
// Parse node_def's attributes into op. Two paths:
//  - no ParseParamByOperatorFunc registered for the type: delegate to the
//    TensorFlow op parser (ParseParams on the NodeDef directly);
//  - otherwise: auto-map the NodeDef onto a ge::Operator first, then run the
//    custom parser adapter on that operator.
// Finally propagates the optional QoS service-label attr onto the OpDesc.
Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op,
const shared_ptr<OpParser> &op_parser) {
GE_CHECK_NOTNULL(node_def);
GE_CHECK_NOTNULL(op);
GE_CHECK_NOTNULL(op_parser);
string node_name = node_def->name();
string node_op = node_def->op();
Status status = FAILED;
domi::ParseParamByOpFunc parse_param_by_op_fn = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(node_op);
if (parse_param_by_op_fn == nullptr) {
// Built-in TF parser path: works on the NodeDef directly.
shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
GE_CHECK_NOTNULL(tensorflow_op_parser);
status = tensorflow_op_parser->ParseParams(node_def, op);
if (status != SUCCESS) {
GELOGE(status, "Parse params for node[%s] failed", node_name.c_str());
return status;
}
} else {
// Custom parser path: build an intermediate ge::Operator via auto-mapping.
ge::Operator op_src(node_def->name().c_str(), node_def->op().c_str());
status = domi::OperatorAutoMapping(node_def, op_src);
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", node_def->name().c_str(),
node_def->op().c_str());
GELOGE(status, "Node[%s] auto mapping failed.", node_name.c_str());
return status;
}
std::shared_ptr<ge::TensorFlowCustomParserAdapter> tf_custom_op_parser =
std::dynamic_pointer_cast<ge::TensorFlowCustomParserAdapter>(op_parser);
GE_CHECK_NOTNULL(tf_custom_op_parser);
status = tf_custom_op_parser->ParseParams(op_src, op);
if (status != SUCCESS) {
GELOGE(status, "Parse params for node[%s] failed", ParserUtils::GetOperatorName(op_src).c_str());
return status;
}
}
// Optional attr: carry the QoS service label through as an int attribute.
domi::tensorflow::AttrValue attr;
if (ge::TensorFlowUtil::FindAttrValue(node_def, ATTR_NAME_QOS_SERVICE_LABEL, attr)) {
(void)ge::AttrUtils::SetInt(*op, ATTR_NAME_QOS_SERVICE_LABEL, static_cast<int64_t>(attr.i()));
}
return SUCCESS;
}
// Convert one NodeDef into a graph node and register it in node_map_.
// Handles four cases: scope-inner nodes (delegated to AddScopeInnerNode),
// unmapped types (OpDesc built ad hoc + auto-mapped), plain mapped types
// (parsed via the type's OpParser), and fusion ops (parsed via the fusion
// parser with original-name recording and fusion-result bookkeeping).
Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def, ge::ComputeGraphPtr &graph,
shared_ptr<ge::ScopeGraph> &scope_graph) {
GE_CHECK_NOTNULL(node_def);
GE_CHECK_NOTNULL(graph);
GE_CHECK_NOTNULL(scope_graph);
domi::tensorflow::AttrValue attr_value;
// Nodes synthesized inside a fusion scope take a dedicated path.
if (ge::TensorFlowUtil::FindAttrValue(node_def, kAttrNameIsScopeInnerNode, attr_value) && attr_value.b()) {
std::mutex graph_mutex;
return AddScopeInnerNode(this, graph, &graph_mutex, node_def);
}
// node is released in destructor
string node_name = node_def->name();
string node_op = node_def->op();
std::map<std::string, std::string>::const_iterator type_it = tensorflow_op_map.find(node_op);
if (type_it == tensorflow_op_map.end()) {
// No plugin mapping for this TF type: build the OpDesc from the NodeDef
// alone and auto-map its attributes.
GELOGI("Can not find,maybe this node has no plugin node_name is %s, node_op is %s ", node_name.c_str(),
node_op.c_str());
ge::OpDescPtr op_desc;
GE_RETURN_IF_ERROR(TransNodeToOpDesc(node_def, op_desc, node_op));
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
GE_CHK_STATUS(domi::OperatorAutoMapping(node_def, op));
op.BreakConnect();
ge::NodePtr node = nullptr;
node = graph->AddNode(op_desc);
if (node == nullptr) {
DeleteFuisonNodeDef();
GELOGE(FAILED, "add node failed.");
return INTERNAL_ERROR;
}
node_map_[node_name] = node;
return SUCCESS;
}
string op_type = type_it->second;
// The type value is obtained from the definition map set of DaVinci.
ge::OpDescPtr op;
GE_RETURN_IF_ERROR(TransNodeToOpDesc(node_def, op, op_type));
bool needFusion = IsFusionOp(scope_graph, node_def);
// The number of inputs and outputs of each operator can be determined after the new IR design model is resolved.
// Add tensordesc to the opdesc object of the operator
// Process change of tensordesc initialization of opdesc,
// Previous process: Tensordesc is constructed according to graph structure in builder stage
// Current process: Tensordesc is determined before the opdesc of the operator is added to the graph
Status status = FAILED;
// create OpParser
shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
GE_CHECK_NOTNULL(factory);
if (!needFusion) {
shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);
// parse op param
status = ParseOpParams(node_def, op, op_parser);
if (status != SUCCESS) {
GELOGE(status, "Parse params for node[%s] failed", node_name.c_str());
return status;
}
}
GELOGI("After op parser op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(),
op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
// checkout op input number with IR
GE_RETURN_IF_ERROR(CheckoutInputNum(op, node_def));
ge::NodePtr node = graph->AddNode(op);
if (node == nullptr) {
DeleteFuisonNodeDef();
GELOGE(FAILED, "add node failed.");
return INTERNAL_ERROR;
}
node_map_[node_name] = node;
if (needFusion) {
// Fusion path: parse params through the fusion parser and record the
// original child node names for later traceability.
shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type);
GE_CHECK_NOTNULL(fusion_op_parser);
// Find all children of the fusion operator
std::map<string, vector<const NodeDef *>>::const_iterator iter = fusion_op_nodedef_map_.find(node_def->name());
if (iter == fusion_op_nodedef_map_.end()) {
REPORT_INNER_ERROR("E19999", "FusionOp node %s has no children node, check invalid", node_name.c_str());
GELOGE(FAILED, "FusionOp node %s has no children node!", node_name.c_str());
return INTERNAL_ERROR;
}
vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second;
// parse fusion node param
status = FusionNodeParseParams(fusion_op_parser, node_def, node);
if (status != SUCCESS) {
GELOGE(status, "Parse params for fusion node[%s] failed", node_name.c_str());
return status;
}
// record original op names
std::vector<std::string> namesTmp;
for (auto &node_def_iter : node_def_v) {
GE_CHECK_NOTNULL(node_def_iter);
std::string nodeName = node_def_iter->name();
namesTmp.push_back(nodeName);
}
ge::GraphUtils::RecordOriginalNames(namesTmp, node);
status = RecordFusionResult(scope_graph, node_def, op);
if (status != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Record fusion result for fusion op: %s failed", op->GetName().c_str());
DeleteFuisonNodeDef();
return status;
}
}
return SUCCESS;
}
  573. void TensorFlowModelParser::GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num,
  574. size_t &output_tensor_num) const {
  575. // The caller guarantees that the pointer is not null
  576. auto iter = op_node_context_map_.find(op_desc->GetName());
  577. if (iter == op_node_context_map_.end()) {
  578. return;
  579. }
  580. const OpNodeContext &op_context = iter->second;
  581. const std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &dest_input_map = op_context.input_map;
  582. // input number
  583. input_tensor_num = 0;
  584. for (auto &input_vec : dest_input_map) {
  585. for (auto &input_v : input_vec.second) {
  586. if (input_v.second != kControlSlot) {
  587. input_tensor_num++;
  588. }
  589. }
  590. }
  591. // output number
  592. const std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &src_output_map = op_context.output_map;
  593. int32_t max_anchor_index = 0;
  594. for (auto &src_output_iter : src_output_map) {
  595. for (auto &index_output_iter : src_output_iter.second) {
  596. if (index_output_iter.first > max_anchor_index) {
  597. max_anchor_index = index_output_iter.first;
  598. }
  599. }
  600. }
  601. output_tensor_num = max_anchor_index + 1;
  602. }
  603. Status TensorFlowModelParser::CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node) const {
  604. GE_CHECK_NOTNULL(node);
  605. GE_CHECK_NOTNULL(op_desc);
  606. if (std::find(kSkipCheckoutInputSizeNodes.begin(), kSkipCheckoutInputSizeNodes.end(), op_desc->GetType()) !=
  607. kSkipCheckoutInputSizeNodes.end()) {
  608. return SUCCESS;
  609. }
  610. // get input and output tensor number
  611. size_t input_tensor_num = 0;
  612. size_t output_tensor_num = 0;
  613. GetInputOutputTensorNum(op_desc, input_tensor_num, output_tensor_num);
  614. // get input and output tensor number from op desc
  615. size_t factory_input_size = op_desc->GetInputsSize();
  616. if (input_tensor_num != factory_input_size) {
  617. ErrorManager::GetInstance().ATCReportErrMessage(
  618. "E19014", {"opname", "value", "reason"},
  619. {op_desc->GetName(), "input number of tensorflow[" + std::to_string(input_tensor_num) + "]",
  620. "should be equal to factory size[" + std::to_string(factory_input_size) + "]"});
  621. GELOGE(FAILED, "op [%s], type[%s], The input number of tensorflow[%zu] should be equal to factory size[%zu]",
  622. op_desc->GetName().c_str(), op_desc->GetType().c_str(), input_tensor_num, factory_input_size);
  623. return FAILED;
  624. }
  625. return SUCCESS;
  626. }
  627. void TensorFlowModelParser::UpdateInputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &input_desc,
  628. const size_t input_tensor_num) {
  629. // The caller guarantees that the pointer is not null
  630. for (size_t i = 0; i < input_tensor_num; ++i) {
  631. if (i < input_desc.size()) {
  632. // i is guaranteed to be valid, no check required.
  633. ge::graphStatus ret = op_desc->UpdateInputDesc(op_desc->GetInputNameByIndex(i), input_desc[i]);
  634. if (ret != ge::GRAPH_SUCCESS) {
  635. // UpdateInputDesc for dynamic intput will be failed, but it will be added in later op parser.
  636. GELOGI("op [%s], type[%s], input(%zu) with name %s is not updated", op_desc->GetName().c_str(),
  637. op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str());
  638. }
  639. } else {
  640. ge::GeTensorDesc input_tensor;
  641. // i is guaranteed to be valid, no check required.
  642. ge::graphStatus ret = op_desc->UpdateInputDesc(op_desc->GetInputNameByIndex(i), input_tensor);
  643. if (ret != ge::GRAPH_SUCCESS) {
  644. // UpdateInputDesc for dynamic intput will be failed, but it will be added in later op parser.
  645. GELOGI("op [%s], type[%s], input(%zu) with name %s is not updated", op_desc->GetName().c_str(),
  646. op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str());
  647. }
  648. }
  649. }
  650. }
  651. void TensorFlowModelParser::UpdateOutputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &output_desc,
  652. size_t output_tensor_num) {
  653. // The caller guarantees that the pointer is not null
  654. for (size_t i = 0; i < output_tensor_num; ++i) {
  655. if (i < output_desc.size()) {
  656. // i is guaranteed to be valid, no check required.
  657. ge::graphStatus ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(i), output_desc[i]);
  658. if (ret != ge::GRAPH_SUCCESS) {
  659. // UpdateOutputDesc for dynamic output will be failed, but it will be added in later op parser.
  660. GELOGI("op [%s], type[%s], output(%zu) with name %s is not updated", op_desc->GetName().c_str(),
  661. op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str());
  662. }
  663. } else {
  664. ge::GeTensorDesc output_tensor;
  665. // i is guaranteed to be valid, no check required.
  666. ge::graphStatus ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(i), output_tensor);
  667. if (ret != ge::GRAPH_SUCCESS) {
  668. // UpdateOutputDesc for dynamic output will be failed, but it will be added in later op parser.
  669. GELOGI("op [%s], type[%s], output(%zu) with name %s is not updated", op_desc->GetName().c_str(),
  670. op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str());
  671. }
  672. }
  673. }
  674. }
// Fill op_desc's input/output tensor descs from the NodeDef's serialized
// tensor-desc attributes (when present), sized by the recorded op-node
// context. Nodes without a recorded context are left untouched.
Status TensorFlowModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc,
const domi::tensorflow::NodeDef *node) const {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(op_desc);
// get input and output attr from tensorflow
const string type = node->op();
domi::tensorflow::AttrValue input_attr_value;
domi::tensorflow::AttrValue output_attr_value;
// temp_op accumulates the parsed input/output GeTensorDescs.
ParserOperator temp_op;
if (ge::TensorFlowUtil::FindAttrValue(node, ge::parser::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) {
GE_CHK_STATUS_RET(ge::TensorFlowUtil::TransTensorDescriptor(input_attr_value, &temp_op,
TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG, type),
"trans input_attr_value failed, op: %s", node->name().c_str());
} else {
GELOGD("Frameworkop has no input tensor desc, name:%s, type:%s.", node->name().c_str(), type.c_str());
}
if (ge::TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) {
GE_CHK_STATUS_RET(ge::TensorFlowUtil::TransTensorDescriptor(output_attr_value, &temp_op,
TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG, type),
"trans output_attr_value failed, op: %s", node->name().c_str());
} else {
GELOGD("Frameworkop has no output tensor desc, name:%s, type:%s.", node->name().c_str(), type.c_str());
}
// No recorded context means no edges were recorded for this op: nothing to size.
auto iter = op_node_context_map_.find(op_desc->GetName());
if (iter == op_node_context_map_.end()) {
return SUCCESS;
}
const std::vector<ge::GeTensorDesc> &input_desc = temp_op.GetInputTensorDesc();
const std::vector<ge::GeTensorDesc> &output_desc = temp_op.GetOutputTensorDesc();
// get input and output tensor number
size_t input_tensor_num = 0;
size_t output_tensor_num = 0;
GetInputOutputTensorNum(op_desc, input_tensor_num, output_tensor_num);
// update input
UpdateInputTensor(op_desc, input_desc, input_tensor_num);
// update output
UpdateOutputTensor(op_desc, output_desc, output_tensor_num);
return SUCCESS;
}
  714. Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) {
  715. GE_CHECK_NOTNULL(graph);
  716. for (auto &src_iter : op_node_context_map_) {
  717. string src_op_name = src_iter.first;
  718. OpNodeContext src_op_node_context = src_iter.second;
  719. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &src_output_map = src_op_node_context.output_map;
  720. // Traverse all output of the op_node
  721. for (auto &src_output_iter : src_output_map) {
  722. string dest_op_name = src_output_iter.first;
  723. auto dest_iter = op_node_context_map_.find(dest_op_name);
  724. if (dest_iter == op_node_context_map_.end()) {
  725. continue;
  726. }
  727. // Find that the output of the source node is equal to the destination node
  728. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &dest_input_map = dest_iter->second.input_map;
  729. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>>::const_iterator
  730. input_iter = dest_input_map.find(src_op_name);
  731. // Find output and input
  732. if (input_iter == dest_input_map.end()) {
  733. continue;
  734. }
  735. auto iter = node_map_.find(src_op_name);
  736. if (iter == node_map_.end()) {
  737. continue;
  738. }
  739. ge::NodePtr src = iter->second;
  740. GE_CHECK_NOTNULL(src);
  741. auto iter1 = node_map_.find(dest_op_name);
  742. if (iter1 == node_map_.end()) {
  743. continue;
  744. }
  745. // Each pair builds an edge
  746. ge::NodePtr dest = iter1->second;
  747. GE_CHECK_NOTNULL(dest);
  748. if (src_output_iter.second.size() != input_iter->second.size()) {
  749. REPORT_INNER_ERROR("E19999", "Input size of op[%s]:%zu is not equal to Output size of op[%s]:%zu.",
  750. src_op_name.c_str(), input_iter->second.size(), dest_op_name.c_str(),
  751. src_output_iter.second.size());
  752. GELOGE(INTERNAL_ERROR, "Input size of op[%s]:%zu is not equal to Output size of op[%s]:%zu.",
  753. src_op_name.c_str(), input_iter->second.size(), dest_op_name.c_str(), src_output_iter.second.size());
  754. return INTERNAL_ERROR;
  755. }
  756. for (auto &outputpair : src_output_iter.second) {
  757. // Get control edge properties
  758. bool control = GetEdgesControlInfo(dest_op_name, outputpair.second);
  759. // Graph create new edge
  760. if (!control) {
  761. GELOGD("Start add edge: from %s:%d to %s:%d.", src->GetName().c_str(), outputpair.first,
  762. dest->GetName().c_str(), outputpair.second);
  763. ge::OutDataAnchorPtr out_archor_ptr = src->GetOutDataAnchor(outputpair.first);
  764. GE_CHECK_NOTNULL(out_archor_ptr);
  765. ge::InDataAnchorPtr in_archor_ptr = dest->GetInDataAnchor(outputpair.second);
  766. GE_CHECK_NOTNULL(in_archor_ptr);
  767. if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) {
  768. REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed",
  769. src->GetName().c_str(), dest->GetName().c_str());
  770. GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str());
  771. return INTERNAL_ERROR;
  772. }
  773. } else {
  774. GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str());
  775. ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor();
  776. GE_CHECK_NOTNULL(in_archor_ptr);
  777. ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor();
  778. GE_CHECK_NOTNULL(out_archor_ptr);
  779. if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) {
  780. REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed",
  781. src->GetName().c_str(), dest->GetName().c_str());
  782. GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str());
  783. return INTERNAL_ERROR;
  784. }
  785. }
  786. }
  787. dest_input_map.erase(input_iter);
  788. }
  789. }
  790. return SUCCESS;
  791. }
  792. Status TensorFlowModelParser::AddFmkNodeDefToMap(const domi::tensorflow::NodeDef *node_def,
  793. vector<string> &op_node_name_list) {
  794. GE_CHECK_NOTNULL(node_def);
  795. const string &node_name = node_def->name();
  796. nodedef_map_[node_name] = node_def;
  797. OpNodeContext op_node_context;
  798. op_node_context_map_[node_name] = op_node_context;
  799. op_node_name_list.push_back(node_name);
  800. return SUCCESS;
  801. }
// Inspect the node's serialized input tensor descs and clear `valid` if any
// dimension value appears in `dims`. A node without the input-tensor-desc
// attribute passes unchanged (SUCCESS, valid untouched).
Status TensorFlowModelParser::CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set<int> &dims,
bool &valid) {
GE_CHECK_NOTNULL(node_def);
domi::tensorflow::AttrValue input_attr_value;
bool is_attr_exist =
ge::TensorFlowUtil::FindAttrValue(node_def, ge::parser::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value);
GE_IF_BOOL_EXEC(!is_attr_exist, return SUCCESS);
GE_CHK_BOOL_EXEC(input_attr_value.has_list(),
REPORT_INNER_ERROR("E19999", "Attr:%s of node_def:%s(%s) is empty, check invalid",
ge::parser::ATTR_NAME_INPUT_TENSOR_DESC.c_str(), node_def->name().c_str(),
node_def->op().c_str());
return PARAM_INVALID, "output attr value vector is empty");
// list contain many TensorDescriptors
domi::tensorflow::AttrValue_ListValue a_list = input_attr_value.list();
for (int32_t i = 0; i < a_list.func_size(); i++) {
ge::GeTensorDesc ge_desc;
int32_t tf_datatype = 0;
GE_CHK_BOOL_RET_STATUS(ge::TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, i, tf_datatype), PARAM_INVALID,
"parse ge_desc failed.");
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
int64_t temp_dim = ge_desc.GetShape().GetDim(j);
// NOTE(review): temp_dim is int64_t but dims holds int — dim values outside
// int range narrow on lookup; confirm dims of interest are always small.
GE_IF_BOOL_EXEC(dims.count(temp_dim) > 0, valid = false);
}
}
return SUCCESS;
}
// Possibly downgrade op_type to FRAMEWORKOP based on shape/topology quirks:
//  - SparseSoftmaxCrossEntropyWithLogits whose input shape contains dim 10
//    (see check_dims) is handled as a framework op and remembered in
//    framework_ops_;
//  - Add/Mul/Mean consuming the output of a remembered framework op are also
//    downgraded to FRAMEWORKOP.
Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type) {
GE_CHECK_NOTNULL(node_def);
bool valid = true;
string node_name = node_def->name();
// Per-type dimension blacklist used by CheckOpShapeDim.
std::map<std::string, set<int>> check_dims = {
{ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, {10}},
};
GE_IF_BOOL_EXEC(
op_type == ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS,
GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape");
GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP;
GELOGI("Set op %s to frameworkop", node_name.c_str());
framework_ops_[node_name] = node_def;
);
);
GE_IF_BOOL_EXEC(
op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN,
for (const string &input_name
: node_def->input()) {
string tmp_input_name;
// Strip control/port decorations from the TF input spec to get the node name.
GE_RETURN_IF_ERROR(CheckInputNodeName(input_name, &tmp_input_name, nullptr, nullptr));
GELOGD("Add or Mul op %s input name is %s", node_name.c_str(), input_name.c_str());
GE_IF_BOOL_EXEC(framework_ops_.find(tmp_input_name) != framework_ops_.end(),
GELOGI("Set op %s to frameworkop", node_name.c_str());
op_type = ge::parser::FRAMEWORKOP;);
});
return SUCCESS;
}
  856. /*
  857. * @ingroup domi_omg
  858. * @brief Mapping TF's datatype to GE's datatype
  859. * @param [in] type, datatype types of operators in TF networks
  860. * @return ge::DataType
  861. */
  862. ge::DataType TensorFlowModelParser::ConvertToGeDataType(const uint32_t type) {
  863. ErrorManager::GetInstance().GenWorkStreamIdDefault();
  864. ge::DataType data_type = domi::TensorAssign::ConvertTensorflowDataType(type);
  865. return data_type;
  866. }
// Thread-pool worker: converts one TF NodeDef into a GE node.
// Runs concurrently (committed from AddFmkNode); shared state is protected —
// the target graph by |graphMutex| and parser->node_map_ by nodeMapMutex_.
// @param [in] parser        owning parser instance (static member fn, explicit this)
// @param [in] graph         graph receiving the new node
// @param [in] graphMutex    guards graph->AddNode across worker threads
// @param [in] scope_graph   scope-fusion info used to detect fusion ops
// @param [in] node_def      TF node to parse (caller guarantees non-null)
// @param [in] error_context per-work-stream error context, re-installed in this thread
// @return SUCCESS on success; FAILED/INTERNAL_ERROR otherwise
Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph,
                                           std::mutex *graphMutex, shared_ptr<ge::ScopeGraph> &scope_graph,
                                           const domi::tensorflow::NodeDef *node_def,
                                           error_message::Context error_context) {
  ErrorManager::GetInstance().SetErrorContext(error_context);
  // The caller guarantees that the pointer is not null
  string node_name = node_def->name();
  string node_op = node_def->op();
  GELOGD("TF op node name = %s, op type= %s", node_name.c_str(), node_op.c_str());
  // Inner nodes produced by scope fusion take a dedicated path and finish here.
  domi::tensorflow::AttrValue attr_value;
  if (ge::TensorFlowUtil::FindAttrValue(node_def, kAttrNameIsScopeInnerNode, attr_value) && attr_value.b()) {
    return AddScopeInnerNode(parser, graph, graphMutex, node_def);
  }
  // The adapted GE op type was recorded earlier by AdaptOpType.
  std::map<std::string, std::string>::const_iterator iterator = parser->adaptedOpTypeMap_.find(node_name);
  if (iterator == parser->adaptedOpTypeMap_.cend()) {
    REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str());
    GELOGE(FAILED, "get adapted op type failed, node name = %s", node_name.c_str());
    return FAILED;
  }
  string op_type = iterator->second;
  // Log printing for determining operator type
  domi::ImplyType implyType = domi::OpRegistry::Instance()->GetImplyType(op_type);
  GE_IF_BOOL_EXEC((implyType == domi::ImplyType::TVM) && (op_type != ge::parser::FRAMEWORKOP),
                  GELOGD("TBE %s parsering", node_op.c_str()););
  GE_IF_BOOL_EXEC((implyType == domi::ImplyType::CCE) && (op_type != ge::parser::FRAMEWORKOP),
                  GELOGD("CCE %s parsering", node_op.c_str()););
  GE_IF_BOOL_EXEC((implyType == domi::ImplyType::HCCL) && (op_type != ge::parser::FRAMEWORKOP),
                  GELOGD("HCCL %s parsering", node_op.c_str()););
  GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP,
                  GELOGD("FRAMEWORKOP %s parsering", node_op.c_str()););
  GELOGD("TF op node name = %s, op type= %s, trans to op type %s", node_name.c_str(), node_op.c_str(), op_type.c_str());
  // Construct operator by IR
  ge::OpDescPtr op;
  ge::Operator op_factory = ge::OperatorFactory::CreateOperator(node_name.c_str(), op_type.c_str());
  // Name mismatch indicates the factory could not build the operator from IR.
  if (ParserUtils::GetOperatorName(op_factory) != node_name) {
    if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) {
      // Whitelisted types are built directly as a bare OpDesc without IR.
      op = ge::parser::MakeShared<ge::OpDesc>(node_name, op_type);
      GE_CHECK_NOTNULL(op);
    } else if (node_name == op_type) {
      // name == type is the marker for a TF function (Defun); lower it to a
      // PartitionedCall, auto-map its attrs, and finish this node right here.
      GE_RETURN_IF_ERROR(parser->DefunToPartitionedCall(node_def, op));
      GE_CHECK_NOTNULL(op);
      ge::Operator op_tmp = ge::OpDescUtils::CreateOperatorFromOpDesc(op);
      GE_CHK_STATUS(domi::OperatorAutoMapping(node_def, op_tmp));
      op_tmp.BreakConnect();
      ge::NodePtr node;
      {
        std::lock_guard<std::mutex> lock(*graphMutex);
        node = graph->AddNode(op);
      }
      GE_CHECK_NOTNULL(node);
      {
        std::lock_guard<std::mutex> lock(parser->nodeMapMutex_);
        parser->node_map_[node_name] = node;
      }
      return SUCCESS;
    } else {
      REPORT_INPUT_ERROR("E10501", std::vector<std::string>({"opname", "optype"}),
                         std::vector<std::string>({node_name, op_type}));
      GELOGE(INTERNAL_ERROR, "op[%s] type[%s] have no ir factory.]", node_name.c_str(), op_type.c_str());
      return FAILED;
    }
  } else {
    // IR construction succeeded: extract the OpDesc and attach tensor descs
    // derived from the NodeDef.
    op = ge::OpDescUtils::GetOpDescFromOperator(op_factory);
    GE_CHECK_NOTNULL(op);
    GELOGD("After GetOpDescFromOperator op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(),
           op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
    GE_RETURN_IF_ERROR(parser->AddTensorDescToOpDesc(op, node_def));
    GELOGD("After AddTensorDescToOpDesc op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(),
           op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
  }
  GELOGD("TF op node name = %s, outpusize= %zu", node_name.c_str(), op->GetAllOutputsDesc().size());
  op_factory.BreakConnect();
  // create OpParser
  shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  GE_CHECK_NOTNULL(factory);
  bool needFusion = parser->IsFusionOp(scope_graph, node_def);
  GELOGD("TF op node name = %s, op type= %s is fusion op(NO: 0; YES: 1)= %d", node_name.c_str(), node_op.c_str(),
         needFusion);
  Status status = FAILED;
  // Non-fusion ops parse their params before the node is added to the graph;
  // fusion ops parse afterwards via FusionNodeParseParams (needs the NodePtr).
  if (!needFusion) {
    shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);
    status = parser->ParseOpParams(node_def, op, op_parser);
    if (status != SUCCESS) {
      GELOGE(status, "Parse params for node[%s] failed", node_name.c_str());
      return status;
    }
  }
  GELOGD("After op parser op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(),
         op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize());
  // checkout op input number with IR
  GE_RETURN_IF_ERROR(parser->CheckoutInputNum(op, node_def));
  if (needFusion) {
    status = RecordFusionResult(scope_graph, node_def, op);
    if (status != SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Record fusion result for fusion op: %s failed", op->GetName().c_str());
      return status;
    }
  }
  ge::NodePtr node;
  {
    std::lock_guard<std::mutex> lock(*graphMutex);
    node = graph->AddNode(op);
  }
  if (node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(),
                      op->GetType().c_str(), graph->GetName().c_str());
    GELOGE(FAILED, "add node failed.");
    return INTERNAL_ERROR;
  }
  if (needFusion) {
    shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type);
    status = parser->FusionNodeParseParams(fusion_op_parser, node_def, node);
    GE_CHK_STATUS_EXEC(status, return status, "Parse Params for node %s failed", node_name.c_str());
  }
  {
    std::lock_guard<std::mutex> lock(parser->nodeMapMutex_);
    parser->node_map_[node_name] = node;
  }
  return SUCCESS;
}
  987. Status TensorFlowModelParser::AdaptOpType(const domi::tensorflow::NodeDef *node_def, bool isDatasetInit) {
  988. // The caller guarantees that the pointer is not null
  989. string node_name = node_def->name();
  990. string node_op = node_def->op();
  991. string op_type;
  992. if (tensorflow_train_op_map.find(node_op) != tensorflow_train_op_map.end()) {
  993. op_type = tensorflow_train_op_map.at(node_op);
  994. GE_CHK_STATUS_RET(CheckOpType(node_def, op_type), "Failed to check op type");
  995. } else {
  996. op_type = ge::parser::FRAMEWORKOP;
  997. domi::tensorflow::AttrValue attr_call_inference;
  998. if ((node_name == node_op) &&
  999. ge::TensorFlowUtil::FindAttrValue(node_def, "_disable_call_shape_inference", attr_call_inference)) {
  1000. op_type = node_op;
  1001. }
  1002. }
  1003. GE_IF_BOOL_EXEC(isDatasetInit, op_type = ge::parser::FRAMEWORKOP);
  1004. adaptedOpTypeMap_[node_name] = op_type;
  1005. return SUCCESS;
  1006. }
// Parses every node named in |op_node_name_list| into a GE node, in parallel,
// then inserts the results into |graph| in list order.
// Phases: (1) add fusion NodeDefs and adapt all op types serially;
// (2) fan out ParseNodeDef to a thread pool, each worker adding its node to a
// temporary graph; (3) join all futures; (4) move the nodes into |graph| via
// AddNodeToGraphAndMarkFormat, preserving the original order.
// @return SUCCESS, or FAILED if any worker failed
Status TensorFlowModelParser::AddFmkNode(ge::ComputeGraphPtr &graph, shared_ptr<ge::ScopeGraph> &scope_graph,
                                         vector<string> &op_node_name_list, bool is_dataset_init) {
  GE_CHECK_NOTNULL(graph);
  GE_CHECK_NOTNULL(scope_graph);
  GE_RETURN_IF_ERROR(AddFusionNodeDef(scope_graph, op_node_name_list));
  size_t op_node_list_size = op_node_name_list.size();
  for (size_t i = 0; i < op_node_list_size; ++i) {
    const string op_node_name = op_node_name_list[i];
    const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name];
    GE_CHECK_NOTNULL(node_def);
    GE_RETURN_IF_ERROR(AdaptOpType(node_def, is_dataset_init));
  }
  GELOGD("Add fusion nodedef and Adapt op type success");
  // Multithreading parallel parsing nodedef
  ThreadPool executor(kThreadNum);
  std::mutex graphMutex;  // serializes AddNode calls on graph_tmp across workers
  std::vector<std::future<Status>> vectorFuture(op_node_list_size);
  // Workers add nodes to a temporary graph; the real graph is populated in
  // order afterwards by AddNodeToGraphAndMarkFormat.
  ge::ComputeGraphPtr graph_tmp = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
  GE_CHECK_NOTNULL(graph_tmp);
  for (size_t j = 0; j < op_node_list_size; j++) {
    const string op_node_name = op_node_name_list[j];
    const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name];
    GE_CHECK_NOTNULL(node_def);
    // The error-manager context is passed by value so each worker thread can
    // re-install it before logging.
    std::future<Status> f =
        executor.commit(TensorFlowModelParser::ParseNodeDef, this, graph_tmp, &graphMutex, scope_graph, node_def,
                        ErrorManager::GetInstance().GetErrorManagerContext());
    if (!f.valid()) {
      GELOGE(FAILED, "Future is invalid");
      return FAILED;
    }
    vectorFuture[j] = std::move(f);
  }
  GELOGD("Parse nodedef success");
  // Wait for the return value of each thread. If the thread does not finish processing, it will block here
  // Note: all futures are drained before failing, so no worker is abandoned.
  bool ret_flag = true;
  size_t futureSize = vectorFuture.size();
  for (size_t i = 0; i < futureSize; ++i) {
    Status retStatus = vectorFuture[i].get();
    if (retStatus != SUCCESS) {
      ret_flag = false;
    }
  }
  if (!ret_flag) {
    return FAILED;
  }
  return AddNodeToGraphAndMarkFormat(graph, op_node_name_list);
}
  1054. Status TensorFlowModelParser::AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &graph,
  1055. const vector<string> &op_node_name_list) {
  1056. // Add ge:: nodeptr to graph in order
  1057. size_t op_node_list_size = op_node_name_list.size();
  1058. for (size_t j = 0; j < op_node_list_size; j++) {
  1059. const string op_node_name = op_node_name_list[j];
  1060. auto iterator = node_map_.find(op_node_name);
  1061. if (iterator == node_map_.end()) {
  1062. REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str());
  1063. GELOGE(FAILED, "add node failed.");
  1064. return INTERNAL_ERROR;
  1065. }
  1066. GE_CHECK_NOTNULL(iterator->second);
  1067. GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed");
  1068. graph->AddNode(iterator->second);
  1069. }
  1070. return SUCCESS;
  1071. }
  1072. Status TensorFlowModelParser::ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *const graph_def,
  1073. shared_ptr<ge::ScopeGraph> &scope_graph) {
  1074. // Identifying scope fusion operators based on scope rules
  1075. GE_CHECK_NOTNULL(graph_def);
  1076. ScopePassManager passmanager;
  1077. PARSER_TIMESTAMP_START(BuildScopeGraph);
  1078. scope_graph = passmanager.BuildScopeGraph(graph_def);
  1079. GE_CHECK_NOTNULL(scope_graph);
  1080. PARSER_TIMESTAMP_END(BuildScopeGraph, "TensorFlowModelParser::BuildScopeGraph");
  1081. PARSER_TIMESTAMP_START(ScopeGraphPass);
  1082. // Validate the non-general scope fusion pass.
  1083. // The parameter is set to the name of the fusion rule.
  1084. // Multiple names can be set and separated by ",".
  1085. std::vector<std::string> enable_pass_names =
  1086. ge::StringUtils::Split(ge::GetParserContext().enable_scope_fusion_passes, ',');
  1087. auto &impl = ge::ScopeFusionPassRegistry::GetInstance().impl_;
  1088. if (impl == nullptr) {
  1089. REPORT_INNER_ERROR("E19999", "ScopeFusionPassRegistry is not properly initialized.");
  1090. GELOGE(ge::MEMALLOC_FAILED, "ScopeFusionPassRegistry is not properly initialized.");
  1091. return ge::MEMALLOC_FAILED;
  1092. }
  1093. for (size_t i = 0; i < enable_pass_names.size(); ++i) {
  1094. if (enable_pass_names[i].empty()) {
  1095. continue;
  1096. }
  1097. if (!impl->SetPassEnableFlag(enable_pass_names[i], true)) {
  1098. GELOGW("Failed to set enable flag of scope fusion pass:%s", enable_pass_names[i].c_str());
  1099. }
  1100. }
  1101. std::vector<std::string> scope_passes_list = impl->GetAllRegisteredPasses();
  1102. Status ret = RunScopeFusionPass(scope_passes_list, passmanager, scope_graph);
  1103. if (ret != SUCCESS) {
  1104. GELOGE(ret, "Run scope fusion failed, ret:%u.", ret);
  1105. return ret;
  1106. }
  1107. PARSER_TIMESTAMP_END(ScopeGraphPass, "TensorFlowModelParser::ScopeGraphPass");
  1108. return SUCCESS;
  1109. }
// Parses a serialized TF GraphDef held in memory into |graph|.
// Pipeline: read proto -> optional trim by user inputs/outputs -> prototype
// passes -> scope fusion -> precheck all nodes -> build node contexts and
// input/output maps -> add nodes -> add edges -> topological sort.
// Precheck errors are collected and only reported at the very end, so the
// whole model is validated in one run.
// @param [in]  data  serialized GraphDef bytes (must not be null)
// @param [in]  size  byte length of |data|
// @param [out] graph resulting compute graph
// @return SUCCESS or a specific parse error
Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) {
  ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  GE_CHECK_NOTNULL(data);
  GE_CHECK_NOTNULL(graph);
  // Store objects parsed from pb files
  domi::tensorflow::GraphDef OriDef;
  bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef);
  if (!read) {
    REPORT_INNER_ERROR("E19999", "read graph proto from binary failed");
    GELOGE(FAILED, "read_proto_from_binary failed.");
    return INTERNAL_ERROR;
  }
  domi::tensorflow::GraphDef graph_def;
  // With no user-designated inputs/outputs there is nothing to trim.
  const bool is_empty_input = GetParserContext().input_dims.empty() && GetParserContext().out_nodes_map.empty();
  if (is_empty_input) {
    graph_def = OriDef;
  } else {
    GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size());
    if (static_cast<bool>(TrimGraph(OriDef, &graph_def))) {
      GELOGE(FAILED, "Trim Graph fail.");
      return INTERNAL_ERROR;
    }
    GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size());
  }
  GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW),
                              "Run ProtoType Pass Failed");
  shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
  Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph);
  if (ret != SUCCESS) {
    GELOGE(ret, "[TF ParseFromMemory] scope fusion failed.");
    return ret;
  }
  GELOGD("[TF ParseFromMemory] scope fusion success");
  // Add nodedef in the model to prechecker and check the general parameters
  for (int i = 0; i < graph_def.node_size(); i++) {
    const domi::tensorflow::NodeDef &node = graph_def.node(i);
    GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()),
                                "Add node_def to PreChecker failed, node name: %s.", node.name().c_str());
    GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().CheckName(&node), "Check node_def name failed, node name: %s.",
                                node.name().c_str());
    // Identity nodes are exempt from the op-type support check.
    if (node.op() != TENSORFLOWF_NODE_OP_IDENTITY) {
      GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().CheckType(&node, true),
                                  "Check node_def type failed, node name: %s.", node.name().c_str());
    }
  }
  bool has_error = false;
  // save node name
  vector<string> op_node_name_list;
  for (int i = 0; i < graph_def.node_size(); i++) {
    const domi::tensorflow::NodeDef *node_def = graph_def.mutable_node(i);
    // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_
    GE_IF_BOOL_EXEC(MaybeFusionOp(scope_graph, node_def),
                    GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str()););
    // Do not exit immediately when there is an error, wait until all errors are collected before exiting
    GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(node_def, op_node_name_list), has_error = true,
                       "add node failed.");
  }
  // The fusion operator has passed the verification.
  // The errors of internal non key operators (which will be ignored later)
  // do not affect the transformation of the whole model,
  // So clear the error information of non key operators
  // This function call affects the return value of prechecker::instance().Haserror()
  GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));
  // Building input and input relationships for all OP nodes
  GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
  GELOGD("[TF ParseFromMemory] get op nodes context from graph success");
  // Infer input formats
  ge::GetParserContext().format = InferInputFormats();
  GELOGD("[TF ParseFromMemory] infer input formats success");
  // Building input-output relationship between fusionop and common op
  GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list));
  ret = AddFusionNodeDef(scope_graph, op_node_name_list);
  if (ret != SUCCESS) {
    GELOGE(ret, "Add fusion NodeDef failed.");
    DeleteFuisonNodeDef();
    return ret;
  }
  GELOGI("TF op node size = %zu.", op_node_name_list.size());
  // Loop analysis of op_nodes and map them to nodes in graph
  // NOTE: every error path below must call DeleteFuisonNodeDef() to release
  // the fusion NodeDefs created above.
  for (size_t i = 0; i < op_node_name_list.size(); i++) {
    GELOGI("TF op node name = %s.", op_node_name_list[i].c_str());
    const string op_node_name = op_node_name_list[i];
    const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name_list[i]];
    if (node_def == nullptr) {
      REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map_, check invalid", op_node_name.c_str());
      GELOGE(INTERNAL_ERROR, "Node def is nullptr, name:%s.", op_node_name.c_str());
      DeleteFuisonNodeDef();
      return INTERNAL_ERROR;
    }
    const string &node_op = node_def->op();
    // Unlike ParseAllGraph, an unsupported op type is a hard error here.
    if (tensorflow_op_map.find(node_op) == tensorflow_op_map.cend()) {
      DeleteFuisonNodeDef();
      REPORT_INNER_ERROR("E19999", "Op type %s unsupport", node_op.c_str());
      GELOGE(FAILED, "Unsupport op type %s", node_op.c_str());
      return INTERNAL_ERROR;
    }
    ret = AddNode(node_def, graph, scope_graph);
    if (ret != SUCCESS) {
      GELOGE(ret, "Add node failed, name:%s.", op_node_name.c_str());
      DeleteFuisonNodeDef();
      return ret;
    }
  }
  DeleteFuisonNodeDef();
  GE_RETURN_IF_ERROR(AddEdges(graph));
  GE_RETURN_IF_ERROR(graph->TopologicalSorting());
  // Surface any precheck errors collected earlier in one shot.
  has_error = has_error || PreChecker::Instance().HasError();
  if (has_error) {
    GELOGE(PARAM_INVALID, "Precheck has errors.");
    return PARAM_INVALID;
  }
  GELOGI("[TF ParseFromMemory] Parse from memory success.");
  return SUCCESS;
}
  1224. Status TensorFlowModelParser::GetFunctionProto(const string &file,
  1225. domi::tensorflow::GraphDefLibrary &graph_def_library) {
  1226. int pos = file.rfind('/');
  1227. string graph_def_path = (pos == -1) ? kFuncDefLibraryFilePath : file.substr(0, pos) + "/" + kFuncDefLibraryFilePath;
  1228. GELOGI("Function def libraray path is %s.", graph_def_path.c_str());
  1229. bool read = ge::parser::ReadProtoFromText(graph_def_path.c_str(), &graph_def_library);
  1230. if (!read) {
  1231. GELOGE(INTERNAL_ERROR,
  1232. "Get subgraph library failed. "
  1233. "The model contains function operators. "
  1234. "Need to use the script func2graph.py in the atc package to save the subgraphs to graph_def_library.pbtxt");
  1235. ErrorManager::GetInstance().ATCReportErrMessage("E12029");
  1236. return FAILED;
  1237. }
  1238. GELOGI("Get subgraph library success.");
  1239. return SUCCESS;
  1240. }
  1241. Status TensorFlowModelParser::Parse(const char *file, ge::Graph &graph) {
  1242. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  1243. GE_CHECK_NOTNULL(file);
  1244. ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
  1245. GE_CHECK_NOTNULL(root_graph);
  1246. Status ret = Parse(file, root_graph);
  1247. if (ret != SUCCESS) {
  1248. GELOGE(ret, "Parser graph %s failed.", ParserUtils::GetGraphName(graph).c_str());
  1249. return ret;
  1250. }
  1251. GELOGI("Parser graph %s success.", ParserUtils::GetGraphName(graph).c_str());
  1252. return SUCCESS;
  1253. }
// Parses the model file at |model_path| into |root_graph|, including any TF
// function subgraphs. Works through a task queue: the root GraphDef is parsed
// first, then each function call discovered generates a subgraph task whose
// GraphDef is looked up (lazily, on first need) in the companion
// graph_def_library.pbtxt file.
// @param [in]  model_path binary GraphDef file path (must not be null)
// @param [out] root_graph destination root compute graph
// @return SUCCESS or a specific parse error
Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &root_graph) {
  GE_CHECK_NOTNULL(model_path);
  GE_CHECK_NOTNULL(root_graph);
  GELOGI("Parse file %s", model_path);
  // Store objects parsed from pb files
  domi::tensorflow::GraphDef ori_def;
  bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def);
  if (!read) {
    GELOGE(FAILED, "read tensorflow file failed when the inupt param value of --framework is 3.");
    return INTERNAL_ERROR;
  }
  // Trim graph by user input and output.
  domi::tensorflow::GraphDef graph_def;
  if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) {
    graph_def = ori_def;
  } else {
    GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size());
    if (static_cast<bool>(TrimGraph(ori_def, &graph_def))) {
      GELOGE(FAILED, "Trim Graph fail.");
      return INTERNAL_ERROR;
    }
    GELOGI("After Trim, The graph_def.node size is:%d", graph_def.node_size());
  }
  // Construct ParseArg for root graph.
  google::protobuf::Message *root_proto = &graph_def;
  std::deque<ParseArg> tasks;
  tasks.push_back({root_proto, "root", nullptr, "", root_graph});
  // Get sub graph from graph_def_library.pbtxt which prepared before and stored in model_path.
  std::map<std::string, domi::tensorflow::GraphDef> function_name_to_graphdef;
  // Parse all root graph and sub graph.
  while (!tasks.empty()) {
    auto arg = tasks.front();
    tasks.pop_front();
    // A null proto marks a subgraph task: resolve its GraphDef by function
    // name, loading the function library on first use.
    if (arg.proto == nullptr) {
      if (function_name_to_graphdef.empty() && (ori_def.library().function_size() > 0)) {
        GELOGI("Graph has function size: %d ", ori_def.library().function_size());
        domi::tensorflow::GraphDefLibrary graph_def_library;
        GE_CHK_STATUS_RET(GetFunctionProto(model_path, graph_def_library));
        for (auto &ge_graph_def : graph_def_library.graph_def()) {
          function_name_to_graphdef[ge_graph_def.name()] = ge_graph_def.graph();
          GELOGD("Graph_def name: %s, node size: %d", ge_graph_def.name().c_str(), ge_graph_def.graph().node_size());
        }
      }
      const std::map<std::string, domi::tensorflow::GraphDef>::const_iterator
          iter = function_name_to_graphdef.find(arg.function_name);
      if (iter == function_name_to_graphdef.end()) {
        ErrorManager::GetInstance().ATCReportErrMessage("E12013", {"functionname"}, {arg.function_name});
        GELOGE(FAILED, "Failed to get subgraph by function name %s", arg.function_name.c_str());
        return FAILED;
      }
      arg.proto = &(iter->second);
    }
    GELOGI("Begin to parse graph %s", arg.function_name.c_str());
    auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW);
    auto ret = model_parser->ParseAllGraph(arg.proto, arg.graph);
    if (ret != SUCCESS) {
      GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(),
             arg.graph->GetName().c_str());
      return ret;
    }
    ret = PostOpProcessForSubgraph(arg);
    if (ret != SUCCESS) {
      // the error log has been printed inner the function
      return ret;
    }
    // Discover function calls inside the freshly parsed graph and queue them
    // as new subgraph tasks for the next iterations.
    ret = GenSubgraphParseTasks(arg.graph, tasks);
    if (ret != SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Failed to gen tasks on graph:%s for next iteration", arg.graph->GetName().c_str());
      GELOGE(ret, "Failed to gen tasks on graph %s for next iteration", arg.graph->GetName().c_str());
      return ret;
    }
  }
  return SUCCESS;
}
// Parses one (root or sub) GraphDef proto into |graph|.
// The proto is copied so the caller's message is never mutated. Pipeline:
// prototype passes -> scope fusion -> custom-op const optimization ->
// precheck -> node contexts / input formats -> add nodes -> add edges ->
// one-to-many expansion -> isolate-node removal -> topological sort.
// @param [in]  proto serialized-in-memory GraphDef (must not be null)
// @param [out] graph destination compute graph
// @return SUCCESS or a specific parse error
Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
  ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  GE_CHECK_NOTNULL(proto);
  GE_CHECK_NOTNULL(graph);
  const domi::tensorflow::GraphDef *ori_graph =
      ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto);
  // Make a copy for operation without modifying the original graph def.
  domi::tensorflow::GraphDef graph_def = *ori_graph;
  GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&graph_def, domi::TENSORFLOW),
                              "Run ProtoType Pass Failed");
  shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
  Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph);
  if (ret != SUCCESS) {
    GELOGE(ret, "[TF Parse] scope fusion failed.");
    return ret;
  }
  GELOGD("[TF Parse] scope fusion success.");
  GE_RETURN_IF_ERROR(OptimizeConstNodes4CustomOp(&graph_def));
  GELOGD("[TF Parse] optimize const nodes for custom op base success.");
  // Add nodedef in the model to prechecker and check the general parameters
  // Prevent data residue in multiple calls
  PreChecker::Instance().Clear();
  for (int i = 0; i < graph_def.node_size(); i++) {
    const domi::tensorflow::NodeDef &node = graph_def.node(i);
    GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()),
                                "Add node_def to PreChecker failed, node name: %s.", node.name().c_str());
    if (PreChecker::Instance().CheckName(&node) != SUCCESS) {
      GELOGE(FAILED, "Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str());
      return FAILED;
    }
    // Identity nodes are exempt from the op-type support check.
    if (node.op() != TENSORFLOWF_NODE_OP_IDENTITY) {
      if (PreChecker::Instance().CheckType(&node, true) != SUCCESS) {
        GELOGE(FAILED, "Check op[%s]'s optype failed, type is not supported.", node.name().c_str());
        return FAILED;
      }
    }
  }
  bool has_error = false;
  // save node name
  vector<string> op_node_name_list;
  for (int i = 0; i < graph_def.node_size(); i++) {
    const domi::tensorflow::NodeDef *node_def = graph_def.mutable_node(i);
    // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_
    if (MaybeFusionOp(scope_graph, node_def)) {
      GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str());
    }
    // Do not exit immediately when there is an error, wait until all errors are collected before exiting
    GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(node_def, op_node_name_list), has_error = true);
  }
  // The fusion operator has passed the verification.
  // The errors of internal non key operators (which will be ignored later)
  // do not affect the transformation of the whole model,
  // So clear the error information of non key operators
  // This function call affects the return value of prechecker::instance().Haserror()
  GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));
  // Building input and input relationships for all OP nodes
  GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
  GELOGD("[TF Parse] get op nodes context from graph success.");
  // Infer input formats
  ge::GetParserContext().format = InferInputFormats();
  GELOGD("[TF Parse] infer input formats success.");
  // Building input-output relationship between fusionop and common op
  GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list));
  GELOGD("[TF Parse] update all node op context success.");
  // set user-designate-inputs-order
  std::vector<std::string> user_inputs_order;
  for (auto &input : ge::GetParserContext().user_input_dims) {
    user_inputs_order.push_back(input.first);
  }
  graph->SetInputsOrder(user_inputs_order);
  ret = AddFusionNodeDef(scope_graph, op_node_name_list);
  if (ret != SUCCESS) {
    GELOGE(ret, "Add fusion NodeDef failed.");
    DeleteFuisonNodeDef();
    return ret;
  }
  GELOGI("TF op node size = %zu.", op_node_name_list.size());
  // Loop analysis of op_nodes and map them to nodes in graph
  // NOTE: every error path below must call DeleteFuisonNodeDef() to release
  // the fusion NodeDefs created above.
  for (size_t i = 0; i < op_node_name_list.size(); i++) {
    GELOGI("TF op node name = %s.", op_node_name_list[i].c_str());
    const string op_node_name = op_node_name_list[i];
    const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name_list[i]];
    if (node_def == nullptr) {
      REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map_, check invalid", op_node_name.c_str());
      GELOGE(INTERNAL_ERROR, "Cannot find [%s] in nodedef map.", op_node_name_list[i].c_str());
      DeleteFuisonNodeDef();
      return INTERNAL_ERROR;
    }
    const string &node_op = node_def->op();
    // Unlike ParseFromMemory, an unmapped op type is only a warning here.
    if (tensorflow_op_map.find(node_op) == tensorflow_op_map.end()) {
      GELOGW("%s not found in tensorflow_op_map.", node_op.c_str());
    }
    ret = AddNode(node_def, graph, scope_graph);
    if (ret != SUCCESS) {
      GELOGE(ret, "Add op[%s] failed.", node_def->name().c_str());
      DeleteFuisonNodeDef();
      return ret;
    }
  }
  GELOGD("[TF Parse] parse tf node to geop success.");
  DeleteFuisonNodeDef();
  GE_RETURN_IF_ERROR(AddEdges(graph));
  // Expand one-to-many mappings, then track where the designated outputs ended up.
  Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
  ParserUtils::OutputMapping final_output_nodes;
  GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes));
  GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes));
  GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
  GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));
  GE_RETURN_IF_ERROR(graph->TopologicalSorting());
  // Surface the precheck errors collected earlier in one shot.
  if (has_error) {
    GELOGE(PARAM_INVALID, "Precheck has errors.");
    return PARAM_INVALID;
  }
  GELOGI("[TF Parser] Parse proto success.");
  return SUCCESS;
}
  1444. Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) {
  1445. // Build the input relationship first
  1446. for (auto &iter : op_node_context_map_) {
  1447. map<string, std::vector<std::pair<int32_t, int32_t>>> input_map;
  1448. const string &op_node_name = iter.first;
  1449. GE_RETURN_IF_ERROR(GetOpNodeInputMap(op_node_name, input_map));
  1450. OpNodeContext &op_node_context = iter.second;
  1451. op_node_context.input_map = input_map;
  1452. }
  1453. // Then build the output relationship
  1454. GE_RETURN_IF_ERROR(GetOpNodeOutputMap(graph_def));
  1455. return SUCCESS;
  1456. }
  1457. // Get the input relation of opnode includeing input_op and input_const
  1458. Status TensorFlowModelParser::GetOpNodeInputMap(const string &op_node_name,
  1459. map<string, std::vector<std::pair<int32_t, int32_t>>> &input_map) {
  1460. // Get the current nodedef according to the node_name
  1461. const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name];
  1462. GE_CHECK_NOTNULL(node_def);
  1463. int32_t input_index = 0;
  1464. int32_t output_index = 0;
  1465. for (const string &input_node_name : node_def->input()) {
  1466. GELOGD("Get Op InputMap, node_name : %s, input node:%s", node_def->name().c_str(), input_node_name.c_str());
  1467. string tmp_node_name;
  1468. bool control = false;
  1469. GE_RETURN_IF_ERROR(CheckInputNodeName(input_node_name, &tmp_node_name, &output_index, &control));
  1470. input_map[tmp_node_name].push_back({output_index, control ? kControlSlot : input_index});
  1471. SaveEdgesControlInfo(node_def->name(), control);
  1472. input_index = control ? input_index : input_index + 1;
  1473. }
  1474. return SUCCESS;
  1475. }
  1476. Status TensorFlowModelParser::GetOpNodeOutputMap(const domi::tensorflow::GraphDef &graph_def) {
  1477. // Loop through all nodes in graphdef
  1478. for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) {
  1479. auto currentIter = op_node_context_map_.find(node_def.name());
  1480. if (currentIter != op_node_context_map_.end()) {
  1481. OpNodeContext &op_node_context = currentIter->second;
  1482. // Find all input nodes of the current node
  1483. for (auto &inputIter : op_node_context.input_map) {
  1484. auto iter = op_node_context_map_.find(inputIter.first);
  1485. if (iter != op_node_context_map_.end()) {
  1486. std::vector<std::pair<int32_t, int32_t>> inputpairs = inputIter.second;
  1487. OpNodeContext &op_node_context1 = iter->second;
  1488. op_node_context1.output_map[node_def.name()].assign(inputpairs.begin(), inputpairs.end());
  1489. }
  1490. }
  1491. }
  1492. }
  1493. return SUCCESS;
  1494. }
  1495. Status TensorFlowModelParser::GeStoi(const string &input_node_name, const string &index_str, int32_t *index) {
  1496. try {
  1497. int32_t tmp_index = static_cast<int32_t>(std::stoi(index_str.c_str(), nullptr, 10));
  1498. *index = tmp_index;
  1499. } catch (std::invalid_argument &) {
  1500. ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"},
  1501. {"input_node_name(" + input_node_name + ")", index_str});
  1502. GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is invalid argument!", input_node_name.c_str(),
  1503. index_str.c_str());
  1504. return INTERNAL_ERROR;
  1505. } catch (std::out_of_range &) {
  1506. ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"},
  1507. {"input_node_name(" + input_node_name + ")", index_str});
  1508. GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is out of range!", input_node_name.c_str(),
  1509. index_str.c_str());
  1510. return INTERNAL_ERROR;
  1511. } catch (...) {
  1512. ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"},
  1513. {"input_node_name(" + input_node_name + ")", index_str});
  1514. GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is bad argument!", input_node_name.c_str(),
  1515. index_str.c_str());
  1516. return INTERNAL_ERROR;
  1517. }
  1518. return SUCCESS;
  1519. }
  1520. Status TensorFlowModelParser::CheckInputNodeName(const string &input_node_name, string *node_name, int32_t *index,
  1521. bool *control) {
  1522. // Processing scene: input: "^fastrcnn_predictions/map/while/Identity""
  1523. string tmp_input_node_name = input_node_name;
  1524. if (tmp_input_node_name.find("^") == 0) {
  1525. tmp_input_node_name = tmp_input_node_name.substr(1, tmp_input_node_name.length() - 1);
  1526. if (control != nullptr) {
  1527. *control = true;
  1528. }
  1529. } else {
  1530. if (control != nullptr) {
  1531. *control = false;
  1532. }
  1533. }
  1534. auto find = tmp_input_node_name.find(":");
  1535. if (find == string::npos) {
  1536. *node_name = tmp_input_node_name;
  1537. if (index == nullptr) {
  1538. return SUCCESS;
  1539. }
  1540. *index = 0;
  1541. return SUCCESS;
  1542. }
  1543. string indexstr = tmp_input_node_name.substr(find + 1, tmp_input_node_name.length() - find - 1);
  1544. *node_name = tmp_input_node_name.substr(0, find);
  1545. if (index == nullptr) {
  1546. return SUCCESS;
  1547. }
  1548. if (GeStoi(input_node_name, indexstr, index) != SUCCESS) {
  1549. return INTERNAL_ERROR;
  1550. }
  1551. return SUCCESS;
  1552. }
  1553. Status TensorFlowModelParser::RunScopeFusionPass(const vector<string> &scope_passes_list,
  1554. ScopePassManager &pass_manager,
  1555. shared_ptr<ge::ScopeGraph> &scope_graph) {
  1556. if (scope_passes_list.empty()) {
  1557. return SUCCESS;
  1558. }
  1559. GE_CHECK_NOTNULL(scope_graph);
  1560. auto &impl = ge::ScopeFusionPassRegistry::GetInstance().impl_;
  1561. if (impl == nullptr) {
  1562. REPORT_INNER_ERROR("E19999", "ScopeFusionPassRegistry is not properly initialized.");
  1563. GELOGE(ge::MEMALLOC_FAILED, "ScopeFusionPassRegistry is not properly initialized.");
  1564. return ge::MEMALLOC_FAILED;
  1565. }
  1566. for (auto &pass_name : scope_passes_list) {
  1567. auto pass = impl->CreateScopeFusionPass(pass_name);
  1568. if (pass == nullptr) {
  1569. REPORT_INNER_ERROR("E19999", "Scope fusion pass[%s] is not registered.", pass_name.c_str());
  1570. GELOGE(INTERNAL_ERROR, "Scope fusion pass[%s] is not registered.", pass_name.c_str());
  1571. return INTERNAL_ERROR;
  1572. }
  1573. Status ret = pass_manager.AddPass(pass);
  1574. if (ret != SUCCESS) {
  1575. REPORT_CALL_ERROR("E19999", "Add scope fusion pass[%s] failed.", pass_name.c_str());
  1576. GELOGE(INTERNAL_ERROR, "Add scope fusion pass[%s] failed.", pass_name.c_str());
  1577. return INTERNAL_ERROR;
  1578. }
  1579. }
  1580. Status ret = pass_manager.Run(scope_graph);
  1581. if (ret != SUCCESS && ret != domi::SCOPE_NOT_CHANGED) {
  1582. GELOGE(FAILED, "Run scope fusion pass failed, ret:%u.", ret);
  1583. return FAILED;
  1584. }
  1585. return SUCCESS;
  1586. }
// Record node_def in the fusion bookkeeping maps when the scope graph reports it
// as the child of a fusion op. Returns true when the node belongs to a fusion op.
// NOTE(review): this function returns bool but uses GE_CHECK_NOTNULL, which
// normally expands to a Status-style early return — confirm the implicit
// conversion on null input is intended.
bool TensorFlowModelParser::MaybeFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph,
                                          const domi::tensorflow::NodeDef *node_def) {
  GE_CHECK_NOTNULL(scope_graph);
  GE_CHECK_NOTNULL(node_def);
  // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_
  ge::ScopeFusionOpInfo info;
  std::vector<ge::ScopeFusionOpInfo> info_list;
  auto &impl = scope_graph->impl_;
  if (impl->IsFusionOpChild(node_def->name(), info_list)) {
    // Case 1: the scope pass produced explicit fusion entries; register the node
    // under every fusion op it belongs to.
    GE_IF_BOOL_EXEC(info_list.size() > 0,
                    for (size_t i = 0; i < info_list.size(); ++i) {
                      fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type);
                      fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].description);
                      fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(node_def);
                      // Dropout (via Add/RandomUniform) and LayerNorm (via Mean)
                      // additionally keep one of the node's inputs with the fusion op.
                      if (info_list[i].fusion_op_type == "Dropout" &&
                          (node_def->op() == "Add" || node_def->op() == "RandomUniform")) {
                        fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(0)]);
                      }
                      if (info_list[i].fusion_op_type == "LayerNorm" && node_def->op() == "Mean") {
                        fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(1)]);
                      }
                      fusion_op_policy_[info_list[i].fusion_node_name] = info_list[i].scope_pass;
                      fusion_op_children_[node_def->name()] = info_list[i];
                    });
    // Case 2: no entries were produced — fall back to the default-constructed info.
    GE_IF_BOOL_EXEC(info_list.size() == 0, fusion_op_type_map_[info.fusion_node_name].push_back(info.fusion_op_type);
                    fusion_op_type_map_[info.fusion_node_name].push_back(info.description);
                    fusion_op_nodedef_map_[info.fusion_node_name].push_back(node_def);
                    fusion_op_policy_[info.fusion_node_name] = info.scope_pass;
                    fusion_op_children_[node_def->name()] = info);
    return true;
  }
  return false;
}
  1620. bool TensorFlowModelParser::IsFusionOpChild(const string &node_name, ge::ScopeFusionOpInfo *info) {
  1621. GE_CHK_BOOL_EXEC(info != nullptr, REPORT_CALL_ERROR("E19999", "Param info is nullptr, check invalid");
  1622. return false, "fusion info is null.");
  1623. // 1.View in full match fusion strategy first
  1624. // 2.View in scope fusion policy then
  1625. auto iter = fusion_op_children_.find(node_name);
  1626. if (iter != fusion_op_children_.end()) {
  1627. info->node_name = fusion_op_children_[node_name].node_name;
  1628. info->fusion_node_name = fusion_op_children_[node_name].fusion_node_name;
  1629. info->fusion_op_type = fusion_op_children_[node_name].fusion_op_type;
  1630. info->description = fusion_op_children_[node_name].description;
  1631. info->scope_pass = fusion_op_children_[node_name].scope_pass;
  1632. return true;
  1633. }
  1634. return false;
  1635. }
  1636. bool TensorFlowModelParser::FusionOpChildIgnore(const shared_ptr<ge::ScopeGraph> &scope_graph,
  1637. const ge::ScopeFusionOpInfo &info) {
  1638. GE_CHECK_NOTNULL(scope_graph);
  1639. bool ignore = false;
  1640. if (info.scope_pass) {
  1641. // Scope fusion strategy
  1642. auto &impl = scope_graph->impl_;
  1643. ignore = impl->FusionOpChildIgnore(info);
  1644. }
  1645. return ignore;
  1646. }
  1647. bool TensorFlowModelParser::IsFusionOp(const shared_ptr<ge::ScopeGraph> &scope_graph,
  1648. const domi::tensorflow::NodeDef *node_def) {
  1649. // The caller guarantees that the pointer is not null
  1650. auto &impl = scope_graph->impl_;
  1651. return (impl->IsFusionOp(node_def));
  1652. }
  1653. Status TensorFlowModelParser::GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  1654. const int32_t old_index, int32_t &new_index) {
  1655. GE_CHECK_NOTNULL(scope_graph);
  1656. if (info.scope_pass) {
  1657. auto &impl = scope_graph->impl_;
  1658. return impl->GetInputOrOutputIndex(info, old_index, true, new_index);
  1659. }
  1660. GELOGE(INTERNAL_ERROR, "Fusion op should come from scope fusion pass, node name:%s, fusion node name:%s",
  1661. info.node_name.c_str(), info.fusion_node_name.c_str());
  1662. return INTERNAL_ERROR;
  1663. }
  1664. Status TensorFlowModelParser::GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  1665. const int32_t old_index, int32_t &new_index) {
  1666. GE_CHECK_NOTNULL(scope_graph);
  1667. if (info.scope_pass) {
  1668. auto &impl = scope_graph->impl_;
  1669. return impl->GetInputOrOutputIndex(info, old_index, false, new_index);
  1670. }
  1671. GELOGE(INTERNAL_ERROR, "Fusion op should come from scope fusion pass, node name:%s, fusion node name:%s",
  1672. info.node_name.c_str(), info.fusion_node_name.c_str());
  1673. return INTERNAL_ERROR;
  1674. }
  1675. bool TensorFlowModelParser::ConstOpNeedUpdate(const string &op_name) {
  1676. if (nodedef_map_[op_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
  1677. // Normal op need to update
  1678. return true;
  1679. } else {
  1680. auto iter = op_node_context_map_.find(op_name);
  1681. if (iter != op_node_context_map_.end()) {
  1682. ge::ScopeFusionOpInfo info;
  1683. auto outmap = iter->second.output_map;
  1684. for (auto &out_node : outmap) {
  1685. // if the const op output connected to are all fusion ops and the cosnt op is not in the update vector
  1686. if (!IsFusionOpChild(out_node.first, &info)) {
  1687. return true;
  1688. }
  1689. }
  1690. }
  1691. return true;
  1692. }
  1693. }
// Rebuild all node contexts after scope fusion: contexts of fusion-op children
// are merged into one context keyed by the fusion node name, ordinary nodes get
// their edge indexes remapped in place, and op_node_name_list is rewritten to
// the surviving node names. Finishes by normalizing every remaining context.
Status TensorFlowModelParser::UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> &scope_graph,
                                                     vector<string> &op_node_name_list) {
  GE_CHECK_NOTNULL(scope_graph);
  vector<string> tmp_op_node_name_list;
  map<string, OpNodeContext> tmp_fusion_op_node_context_map;
  for (auto &op_node_name : op_node_name_list) {
    auto iter = op_node_context_map_.find(op_node_name);
    if (iter != op_node_context_map_.end()) {
      ge::ScopeFusionOpInfo info;
      // Const children of fusion ops are deliberately handled as common ops.
      if (IsFusionOpChild(op_node_name, &info) && nodedef_map_[op_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
        // This node is a fusion operator
        const std::map<std::string, OpNodeContext>::const_iterator
            fusion_iter = tmp_fusion_op_node_context_map.find(info.fusion_node_name);
        if (fusion_iter == tmp_fusion_op_node_context_map.end()) {
          // First child seen for this fusion node: create its (empty) context.
          OpNodeContext op_node_context;
          tmp_fusion_op_node_context_map[info.fusion_node_name] = op_node_context;
          tmp_op_node_name_list.push_back(info.fusion_node_name);
        }
        OpNodeContext &fusion_op_node_context = tmp_fusion_op_node_context_map[info.fusion_node_name];
        OpNodeContext &normal_op_node_context = op_node_context_map_[op_node_name];
        GE_RETURN_IF_ERROR(UpdateFusionOpContext(scope_graph, info, fusion_op_node_context, normal_op_node_context));
        // Delete fusion operator context
        op_node_context_map_.erase(iter);
      } else {
        // This node is a common operator
        OpNodeContext &normal_op_node_context = op_node_context_map_[op_node_name];
        GE_RETURN_IF_ERROR(UpdateNormalOpContext(scope_graph, op_node_name, normal_op_node_context));
        tmp_op_node_name_list.push_back(op_node_name);
      }
    }
  }
  // update op_node_name_list
  op_node_name_list.clear();
  op_node_name_list.assign(tmp_op_node_name_list.begin(), tmp_op_node_name_list.end());
  // update op_node_context_map_
  for (const auto &iter : tmp_fusion_op_node_context_map) {
    op_node_context_map_[iter.first] = iter.second;
  }
  // Normalized context
  GE_RETURN_IF_ERROR(NormalizeAllNodeOpContext());
  return SUCCESS;
}
// Merge one fusion child's edges (normal_op_node_context) into the aggregated
// context of its fusion op (fusion_op_node_context), remapping the edge indexes.
Status TensorFlowModelParser::UpdateFusionOpContext(shared_ptr<ge::ScopeGraph> &scope_graph,
                                                    const ge::ScopeFusionOpInfo &info,
                                                    OpNodeContext &fusion_op_node_context,
                                                    OpNodeContext &normal_op_node_context) {
  GE_CHECK_NOTNULL(scope_graph);
  if (FusionOpChildIgnore(scope_graph, info)) {
    // The inner children operators of the fusion operator can be ignored directly
    // if they do not establish the edge relationship with other outer ordinary / fusion operators
    return SUCCESS;
  }
  // Fold the child's input edges, then its output edges, into the fusion context.
  GE_CHK_STATUS_RET(UppdateInputMap(scope_graph, info, fusion_op_node_context, normal_op_node_context),
                    "UppdateInputMap ret fail");
  GE_CHK_STATUS_RET(UppdateOutputMap(scope_graph, info, fusion_op_node_context, normal_op_node_context),
                    "UppdateOutputMap ret fail");
  return SUCCESS;
}
// Merge the input edges of a fusion child into the fusion op's input map.
// Slot indexes are translated from child-local slots to fusion-op slots; edges
// between two children of the same fusion op are dropped.
Status TensorFlowModelParser::UppdateInputMap(shared_ptr<ge::ScopeGraph> &scope_graph,
                                              const ge::ScopeFusionOpInfo &info, OpNodeContext &fusion_op_node_context,
                                              OpNodeContext &normal_op_node_context) {
  GE_CHECK_NOTNULL(scope_graph);
  for (auto &iter : normal_op_node_context.input_map) {
    string input_node_name = iter.first;
    std::vector<std::pair<int32_t, int32_t>> &pairs = iter.second;
    ge::ScopeFusionOpInfo from_info;
    int32_t from_index = 0;
    int32_t to_index = 0;
    if (!ConstOpNeedUpdate(input_node_name)) {
      GELOGI("%s is const node connected to a fusion child, ignore", input_node_name.c_str());
      continue;
    }
    if (IsFusionOpChild(input_node_name, &from_info)) {
      if (info.fusion_node_name == from_info.fusion_node_name) {
        // Ignore two sub operators in the same fusion operator
        continue;
      }
      // Producer is a child of a different fusion op: remap both endpoints.
      for (auto &pair : pairs) {
        GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, from_info, pair.first, from_index),
                                    "GetOutPutIndex failed ,input_node_name %s.", input_node_name.c_str());
        GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, info, pair.second, to_index),
                                    "GetInPutIndex failed ,input_node_name %s.", input_node_name.c_str());
        fusion_op_node_context.input_map[from_info.fusion_node_name].push_back({from_index, to_index});
        UpdateEdgesControlInfo(info);
        GELOGD("[Update op context] update fusion input map for fusion input, %s:%d TO %s:%d",
               from_info.fusion_node_name.c_str(), from_index, info.fusion_node_name.c_str(), to_index);
      }
    } else {
      // Producer is an ordinary node: only the destination slot needs remapping.
      for (auto &pair : pairs) {
        from_index = pair.first;
        GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, info, pair.second, to_index),
                                    "GetInPutIndex input_node_name %s.", input_node_name.c_str());
        fusion_op_node_context.input_map[input_node_name].push_back({from_index, to_index});
        UpdateEdgesControlInfo(info);
        GELOGD("[Update op context] update fusion input map for normal input, %s:%d TO %s:%d",
               input_node_name.c_str(), from_index, info.fusion_node_name.c_str(), to_index);
      }
    }
  }
  return SUCCESS;
}
// Merge the output edges of a fusion child into the fusion op's output map.
// Slot indexes are translated from child-local slots to fusion-op slots; edges
// between two children of the same fusion op are dropped.
Status TensorFlowModelParser::UppdateOutputMap(shared_ptr<ge::ScopeGraph> &scope_graph,
                                               const ge::ScopeFusionOpInfo &info, OpNodeContext &fusion_op_node_context,
                                               OpNodeContext &normal_op_node_context) {
  GE_CHECK_NOTNULL(scope_graph);
  for (auto &iter : normal_op_node_context.output_map) {
    string output_node_name = iter.first;
    std::vector<std::pair<int32_t, int32_t>> &pairs = iter.second;
    ge::ScopeFusionOpInfo to_info;
    int32_t from_index = 0;
    int32_t to_index = 0;
    if (IsFusionOpChild(output_node_name, &to_info)) {
      if (info.fusion_node_name == to_info.fusion_node_name) {
        // Ignore two sub operators in the same fusion operator
        continue;
      }
      // Consumer is a child of a different fusion op: remap both endpoints.
      for (auto &pair : pairs) {
        GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, info, pair.first, from_index),
                                    "fusion GetOutPutIndex failed ,output_node_name %s.", output_node_name.c_str());
        GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, to_info, pair.second, to_index),
                                    "fusion GetInPutIndex failed ,output_node_name %s.", output_node_name.c_str());
        fusion_op_node_context.output_map[to_info.fusion_node_name].push_back({from_index, to_index});
        GELOGD("[Update op context] update fusion output map for fusion output, %s:%d TO %s:%d",
               info.fusion_node_name.c_str(), from_index, to_info.fusion_node_name.c_str(), to_index);
      }
    } else {
      // Consumer is an ordinary node: only the source slot needs remapping.
      for (auto &pair : pairs) {
        to_index = pair.second;
        GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, info, pair.first, from_index),
                                    "not fusion,GetOutPutIndex failed ,output_node_name %s.", output_node_name.c_str());
        fusion_op_node_context.output_map[output_node_name].push_back({from_index, to_index});
        GELOGD("[Update op context] update fusion output map for normal output, %s:%d TO %s:%d",
               info.fusion_node_name.c_str(), from_index, output_node_name.c_str(), to_index);
      }
    }
  }
  return SUCCESS;
}
  1832. Status TensorFlowModelParser::EraseNormalOpOutputIfChild(shared_ptr<ge::ScopeGraph> &scope_graph,
  1833. const string &op_node_name,
  1834. OpNodeContext &normal_op_node_context) {
  1835. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> tmp_output_map;
  1836. for (auto iter = normal_op_node_context.output_map.begin(); iter != normal_op_node_context.output_map.end();) {
  1837. string output_node_name = iter->first;
  1838. ge::ScopeFusionOpInfo to_info;
  1839. if (IsFusionOpChild(output_node_name, &to_info) &&
  1840. nodedef_map_[output_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
  1841. // Fuse operator, update index
  1842. std::vector<std::pair<int32_t, int32_t>> &pairs = iter->second;
  1843. int32_t to_index = 0;
  1844. for (auto &pair : pairs) {
  1845. int32_t from_index = pair.first;
  1846. GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, to_info, pair.second, to_index),
  1847. "GetInPutIndex failed ,output_node_name %s.", output_node_name.c_str());
  1848. tmp_output_map[to_info.fusion_node_name].push_back({from_index, to_index});
  1849. GELOGD("[Update op context] update normal output map for fusion output, %s:%d TO %s:%d", op_node_name.c_str(),
  1850. from_index, to_info.fusion_node_name.c_str(), to_index);
  1851. }
  1852. iter = normal_op_node_context.output_map.erase(iter);
  1853. } else {
  1854. iter++;
  1855. }
  1856. }
  1857. for (auto &iter : tmp_output_map) {
  1858. normal_op_node_context.output_map[iter.first] = iter.second;
  1859. }
  1860. return SUCCESS;
  1861. }
// Remap the edges of an ordinary (non-fused) node: inputs coming from fusion-op
// children are rekeyed to the fusion node name with translated output indexes;
// outputs are handled by EraseNormalOpOutputIfChild.
Status TensorFlowModelParser::UpdateNormalOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, const string &op_node_name,
                                                    OpNodeContext &normal_op_node_context) {
  GE_CHECK_NOTNULL(scope_graph);
  std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> tmp_input_map;
  for (auto iter = normal_op_node_context.input_map.begin(); iter != normal_op_node_context.input_map.end();) {
    string input_node_name = iter->first;
    ge::ScopeFusionOpInfo from_info;
    // Const children of fusion ops keep their original entries.
    if (IsFusionOpChild(input_node_name, &from_info) &&
        nodedef_map_[input_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
      // Fuse operator, update index
      std::vector<std::pair<int32_t, int32_t>> &pairs = iter->second;
      int32_t from_index = 0;
      for (auto &pair : pairs) {
        int32_t to_index = pair.second;
        GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, from_info, pair.first, from_index),
                                    "GetOutPutIndex failed ,input_node_name %s.", input_node_name.c_str());
        tmp_input_map[from_info.fusion_node_name].push_back({from_index, to_index});
        GELOGD("[Update op context] update normal input map for fusion input, %s:%d TO %s:%d",
               from_info.fusion_node_name.c_str(), from_index, op_node_name.c_str(), to_index);
      }
      // erase() returns the next valid iterator; the remapped entry is re-added below.
      iter = normal_op_node_context.input_map.erase(iter);
    } else {
      iter++;
    }
  }
  Status ret = EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context);
  if (ret != SUCCESS) {
    return ret;
  }
  // Merge the remapped input entries back into the node's input map.
  for (auto &iter : tmp_input_map) {
    normal_op_node_context.input_map[iter.first] = iter.second;
  }
  return SUCCESS;
}
  1896. Status TensorFlowModelParser::NormalizeAllNodeOpContext() {
  1897. for (auto iter = op_node_context_map_.begin(); iter != op_node_context_map_.end();) {
  1898. OpNodeContext &context = iter->second;
  1899. NormalizeInputOrOutputMap(iter->first, context.input_map);
  1900. NormalizeInputOrOutputMap(iter->first, context.output_map);
  1901. if ((context.input_map.size() == 0) && (context.output_map.size() == 0)) {
  1902. GELOGD("[Update op context] node: %s will be removed at the back.", iter->first.c_str());
  1903. iter = op_node_context_map_.erase(iter);
  1904. } else {
  1905. iter++;
  1906. }
  1907. }
  1908. return SUCCESS;
  1909. }
  1910. Status TensorFlowModelParser::NormalizeInputOrOutputMap(
  1911. const string &node_name, std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map) {
  1912. if (context_map.empty()) {
  1913. return SUCCESS;
  1914. }
  1915. for (auto iter = context_map.begin(); iter != context_map.end();) {
  1916. std::vector<std::pair<int32_t, int32_t>> &pairs = iter->second;
  1917. std::vector<std::pair<int32_t, int32_t>> temp_pairs;
  1918. std::set<std::string> compare_set;
  1919. for (auto &pair : pairs) {
  1920. bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.cend()) ||
  1921. (fusion_op_children_.find(iter->first) != fusion_op_children_.cend());
  1922. bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.cend()) ||
  1923. (fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.cend());
  1924. if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) &&
  1925. (is_fusion_child || is_fusion_op)) {
  1926. // The edge will be cut off at the back, ignoring
  1927. continue;
  1928. }
  1929. string name = to_string(pair.first) + ":" + to_string(pair.second);
  1930. const std::set<std::string>::const_iterator compare_iter = compare_set.find(name);
  1931. if (compare_iter != compare_set.end()) {
  1932. // pair<from,to> repeat, ignore
  1933. continue;
  1934. }
  1935. temp_pairs.push_back(pair);
  1936. compare_set.insert(name);
  1937. }
  1938. if (temp_pairs.empty()) {
  1939. // If there is no pair, the context can be deleted
  1940. iter = context_map.erase(iter);
  1941. continue;
  1942. } else {
  1943. iter++;
  1944. }
  1945. pairs.clear();
  1946. pairs.assign(temp_pairs.begin(), temp_pairs.end());
  1947. }
  1948. return SUCCESS;
  1949. }
  1950. void TensorFlowModelParser::DeleteFuisonNodeDef() {
  1951. for (auto &fusion_nodedef : fusion_nodedef_list) {
  1952. GE_DELETE_NEW_SINGLE(fusion_nodedef);
  1953. }
  1954. }
  1955. void TensorFlowModelParser::SaveEdgesControlInfo(const string &node_name, const bool control) {
  1956. if (control) {
  1957. // If the control attribute is true, save the control attribute to edges_control_map
  1958. edges_control_map[node_name].push_back(kControlSlot);
  1959. }
  1960. }
  1961. void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info) {
  1962. const std::map<std::string, std::vector<int32_t>>::const_iterator iter = edges_control_map.find(info.node_name);
  1963. if (iter != edges_control_map.end()) {
  1964. // Delete the original fusion operator node information and add the fusion operator control edge information
  1965. edges_control_map.erase(iter);
  1966. edges_control_map[info.fusion_node_name].push_back(kControlSlot);
  1967. }
  1968. }
  1969. bool TensorFlowModelParser::GetEdgesControlInfo(const string &node_name, const int32_t index) const {
  1970. // If the node name is included, then confirm whether the index is the same
  1971. auto iter = edges_control_map.find(node_name);
  1972. if (iter != edges_control_map.end()) {
  1973. for (auto &i : iter->second) {
  1974. if (i == index) {
  1975. return true;
  1976. }
  1977. }
  1978. }
  1979. return false;
  1980. }
  1981. Status TensorFlowModelParser::ClearFusionOpError(const vector<string> &op_node_name_list) {
  1982. for (const auto &name : op_node_name_list) {
  1983. ge::ScopeFusionOpInfo info;
  1984. if (IsFusionOpChild(name, &info)) {
  1985. const NodeDef *node = nodedef_map_[name];
  1986. GE_CHECK_NOTNULL(node);
  1987. GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Clear(node, "fused and removed."),
  1988. "Clear pre-checking for node %s failed.", node->name().c_str());
  1989. }
  1990. }
  1991. return SUCCESS;
  1992. }
  1993. Status TensorFlowModelParser::ToJson(const char *model_file, const char *json_file) {
  1994. GE_CHK_BOOL_RET_STATUS(model_file != nullptr, FAILED, "model_file is nullptr.");
  1995. GE_CHK_BOOL_RET_STATUS(json_file != nullptr, FAILED, "json_file is nullptr.");
  1996. domi::tensorflow::GraphDef graph_def;
  1997. nlohmann::json j;
  1998. GE_RETURN_WITH_LOG_IF_FALSE(ge::parser::ReadProtoFromBinaryFile(model_file, &graph_def),
  1999. "ReadProtoFromBinaryFile failed, file:%s.", model_file);
  2000. Pb2Json::Message2Json(graph_def, kTfBlackFields, j, true);
  2001. return ModelSaver::SaveJsonToFile(json_file, j);
  2002. }
  2003. Status TensorFlowWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) {
  2004. (void)data;
  2005. (void)size;
  2006. (void)graph;
  2007. return SUCCESS;
  2008. }
  2009. Status TensorFlowWeightsParser::Parse(const char *file, ge::Graph &graph) {
  2010. (void)file;
  2011. (void)graph;
  2012. return SUCCESS;
  2013. }
// Parse a tensorflow GraphDef proto into graph. Pipeline:
//   proto-type pass -> scope fusion -> graphdef optimization -> nodedef
//   collection -> op-context building -> node/edge creation -> one-to-many
//   expansion -> isolate-node removal -> topological sort -> iterator fusion.
Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
  ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  ErrorManager::GetInstance().GenWorkStreamIdDefault();
  PARSER_TIMESTAMP_START(ParseProto);
  GE_CHECK_NOTNULL(proto);
  GE_CHECK_NOTNULL(graph);
  // This entry point marks the parser context as a training parse.
  ge::GetParserContext().train_flag = true;
  const domi::tensorflow::GraphDef *graph_def_in =
      ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto);
  // Make a copy for operation without modifying the original graph def.
  domi::tensorflow::GraphDef graph_def_operation = *graph_def_in;
  domi::tensorflow::GraphDef *graph_def = &graph_def_operation;
  GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(graph_def, domi::TENSORFLOW),
                              "Run ProtoType Pass Failed");
  // Run the registered scope fusion passes; scope_graph records the results.
  shared_ptr<ge::ScopeGraph> scope_graph = nullptr;
  Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph);
  if (ret != SUCCESS) {
    GELOGE(ret, "[TF Parser] scope fusion failed.");
    return ret;
  }
  GELOGD("[TF Parser] scope fusion success");
  bool has_error = false;
  // Graphdef optimizes identity
  PARSER_TIMESTAMP_START(GraphDefOptimize);
  GE_RETURN_IF_ERROR(GraphDefOptimize(graph_def));
  PARSER_TIMESTAMP_END(GraphDefOptimize, "TensorFlowModelParser::GraphDefOptimize");
  GELOGD("[TF Parser] graph def optimize success");
  // Optimization for TVM operator
  PARSER_TIMESTAMP_START(OptimizeConstNodes4CustomOp);
  GE_RETURN_IF_ERROR(OptimizeConstNodes4CustomOp(graph_def));
  PARSER_TIMESTAMP_END(OptimizeConstNodes4CustomOp, "TensorFlowModelParser::OptimizeConstNodes4CustomOp");
  GELOGD("[TF Parser] optimize const nodes for custom op success");
  GE_RETURN_IF_ERROR(GetTensorflowGraphInOutMap(graph_def));
  GE_RETURN_IF_ERROR(RemoveIsolateNode(graph_def));
  vector<string> op_node_name_list;
  bool isDatasetInit = false;
  // Collect every nodedef into the parser maps; fusion children are also
  // registered through MaybeFusionOp.
  PARSER_TIMESTAMP_START(AddFmkNodeDefToMap);
  for (int i = 0; i < graph_def->node_size(); i++) {
    const domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i);
    // Inputless Identity/Snapshot nodes are skipped.
    if (node_def->op() == ge::parser::IDENTITY && node_def->input_size() == 0) {
      continue;
    }
    if (node_def->op() == ge::parser::SNAPSHOT && node_def->input_size() == 0) {
      continue;
    }
    GE_IF_BOOL_EXEC(node_def->op() == "MakeIterator", isDatasetInit = true);
    // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_
    if (MaybeFusionOp(scope_graph, node_def)) {
      GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str());
    }
    // Do not exit immediately when there is an error, wait until all errors are collected before exiting
    ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
    GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
  }
  PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap");
  GELOGI("[TF Parser] TF subgraph isDatasetInit: %d.", isDatasetInit);
  // Build input and output relationships for all OP nodes
  PARSER_TIMESTAMP_START(GetOpNodesContextFromGraph);
  GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def));
  PARSER_TIMESTAMP_END(GetOpNodesContextFromGraph, "TensorFlowModelParser::GetOpNodesContextFromGraph");
  GELOGD("[TF Parser] Get op nodes context from graph success");
  // Building input-output relationship between fusionop and common op
  GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list));
  GELOGI("[TF Parser] TF op node size = %zu.", op_node_name_list.size());
  PARSER_TIMESTAMP_START(AddFmkNode);
  // Loop analysis of op_nodes and map them to nodes in graph
  ret = AddFmkNode(graph, scope_graph, op_node_name_list, isDatasetInit);
  PARSER_TIMESTAMP_END(AddFmkNode, "TensorFlowModelParser::AddFmkNode");
  GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef();
                     return ret, "AddFmkNode failed");
  GELOGD("[TF Parser] Add framework node success");
  // NOTE(review): the AddEdges result is checked only after the one-to-many
  // expansion and DeleteFuisonNodeDef below, so the fusion nodedef copies are
  // released first; however the intermediate GE_RETURN_IF_ERROR paths return
  // without that cleanup — confirm this ordering is intended.
  ret = AddEdges(graph);
  Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
  ParserUtils::OutputMapping final_output_nodes;
  GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes));
  GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes));
  DeleteFuisonNodeDef();
  GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed");
  GELOGD("[TF Parser] Add edges success");
  PARSER_TIMESTAMP_START(RemoveIsolateNode);
  // Delete isolated nodes
  GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
  GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));
  PARSER_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode");
  PARSER_TIMESTAMP_START(TopologicalSorting);
  GE_RETURN_IF_ERROR(graph->TopologicalSorting());
  PARSER_TIMESTAMP_END(TopologicalSorting, "TensorFlowModelParser::TopologicalSorting");
  // Run the iterator fusion pass as a final graph-level optimization.
  ge::parser::PassManager iterator_fusion_pass;
  try {
    (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", new ge::IteratorFusionPass(domi::TENSORFLOW));
  } catch (std::bad_alloc &e) {
    GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs.");
    return INTERNAL_ERROR;
  }
  ret = iterator_fusion_pass.Run(graph);
  if (ret != SUCCESS && ret != ge::NOT_CHANGED) {
    GELOGE(ret, "Run graph passes optimize for preprocess failed, ret:%u.", ret);
    return ret;
  }
  // Fail the parse when any pre-check error was recorded along the way.
  has_error = has_error || PreChecker::Instance().HasError();
  if (has_error) {
    GELOGE(PARAM_INVALID, "Precheck has errors.");
    return PARAM_INVALID;
  }
  GELOGI("[TF Parser] Parse proto success.");
  PARSER_TIMESTAMP_END(ParseProto, "TensorFlowModelParser::ParseProto");
  return SUCCESS;
}
  2122. Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Message *root_proto,
  2123. domi::GetGraphCallback callback, ge::ComputeGraphPtr &root_graph) {
  2124. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  2125. ErrorManager::GetInstance().GenWorkStreamIdDefault();
  2126. GE_CHECK_NOTNULL(root_proto);
  2127. GE_CHECK_NOTNULL(callback);
  2128. GE_CHECK_NOTNULL(root_graph);
  2129. PARSER_TIMESTAMP_START(ParseProtoWithSubgraph);
  2130. std::vector<std::unique_ptr<google::protobuf::Message>> proto_holder;
  2131. std::deque<ParseArg> tasks;
  2132. tasks.push_back({root_proto, "root", nullptr, "", root_graph});
  2133. while (!tasks.empty()) {
  2134. auto arg = tasks.front();
  2135. tasks.pop_front();
  2136. if (arg.proto == nullptr) {
  2137. auto proto = callback(root_proto, arg.function_name);
  2138. if (proto == nullptr) {
  2139. REPORT_CALL_ERROR("E19999", "callback execute failed, func_name:%s", arg.function_name.c_str());
  2140. GELOGE(FAILED, "Failed to get function by name %s", arg.function_name.c_str());
  2141. return FAILED;
  2142. }
  2143. arg.proto = proto.get();
  2144. proto_holder.emplace_back(std::move(proto));
  2145. }
  2146. GELOGI("Begin to parse graph %s", arg.function_name.c_str());
  2147. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW);
  2148. auto ret = model_parser->ParseProto(arg.proto, arg.graph);
  2149. if (ret != SUCCESS) {
  2150. GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(),
  2151. arg.graph->GetName().c_str());
  2152. return ret;
  2153. }
  2154. ret = PostOpProcessForSubgraph(arg);
  2155. if (ret != SUCCESS) {
  2156. // the error log has been printed inner the function
  2157. return ret;
  2158. }
  2159. ret = GenSubgraphParseTasks(arg.graph, tasks);
  2160. if (ret != SUCCESS) {
  2161. GELOGE(ret, "Failed to gen tasks on graph %s for next iteration", arg.graph->GetName().c_str());
  2162. return ret;
  2163. }
  2164. }
  2165. auto add_ret = AddExternalGraph(root_graph);
  2166. if (add_ret != SUCCESS) {
  2167. GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str());
  2168. return add_ret;
  2169. }
  2170. PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph");
  2171. return SUCCESS;
  2172. }
  2173. Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) {
  2174. if (serialized_proto.empty()) {
  2175. GELOGE(FAILED, "Deserialize proto failed as serialized proto is empty");
  2176. return FAILED;
  2177. }
  2178. domi::tensorflow::GraphDef graph_def;
  2179. if (!graph_def.ParseFromString(serialized_proto)) {
  2180. GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed");
  2181. return FAILED;
  2182. }
  2183. return ParseProto(ge::PtrToPtr<domi::tensorflow::GraphDef, const google::protobuf::Message>(&graph_def), graph);
  2184. }
  2185. Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback,
  2186. ge::ComputeGraphPtr &root_graph) {
  2187. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  2188. ErrorManager::GetInstance().GenWorkStreamIdDefault();
  2189. GE_CHECK_NOTNULL(callback);
  2190. GE_CHECK_NOTNULL(root_graph);
  2191. PARSER_TIMESTAMP_START(ParseProtoWithSubgraph);
  2192. std::deque<ParseArg> tasks;
  2193. tasks.push_back({nullptr, "root", nullptr, "", root_graph});
  2194. bool root_parsed = false;
  2195. while (!tasks.empty()) {
  2196. auto arg = tasks.front();
  2197. tasks.pop_front();
  2198. auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW);
  2199. Status ret = SUCCESS;
  2200. if (root_parsed) {
  2201. GELOGI("Begin to parse serialized proto of sub graph %s", arg.function_name.c_str());
  2202. ret = model_parser->ParseProto(callback(arg.function_name), arg.graph);
  2203. } else {
  2204. GELOGI("Begin to parse serialized proto of root graph");
  2205. ret = model_parser->ParseProto(root_proto, arg.graph);
  2206. root_parsed = true;
  2207. }
  2208. if (ret != SUCCESS) {
  2209. GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(),
  2210. arg.graph->GetName().c_str());
  2211. return ret;
  2212. }
  2213. ret = PostOpProcessForSubgraph(arg);
  2214. if (ret != SUCCESS) {
  2215. return ret; // the error log has been printed inner the function
  2216. }
  2217. ret = GenSubgraphParseTasks(arg.graph, tasks);
  2218. if (ret != SUCCESS) {
  2219. GELOGE(ret, "Failed to gen tasks for sub graph of graph %s", arg.graph->GetName().c_str());
  2220. return ret;
  2221. }
  2222. }
  2223. auto add_ret = AddExternalGraph(root_graph);
  2224. if (add_ret != SUCCESS) {
  2225. GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str());
  2226. return add_ret;
  2227. }
  2228. PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph");
  2229. return SUCCESS;
  2230. }
  2231. // For the identity operator whose output is "_retval", optimize it.
  2232. Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map,
  2233. const string &curr_node_name, bool &clear_input_flag) {
  2234. auto context_iter = op_node_context_map_.find(curr_node_name);
  2235. if (context_iter == op_node_context_map_.end()) {
  2236. REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str());
  2237. GELOGE(FAILED, "Can't find op node context.");
  2238. return INTERNAL_ERROR;
  2239. }
  2240. OpNodeContext op_node_context = context_iter->second;
  2241. const std::map<std::string, NodeDef *>::const_iterator node_def_iter = nodedef_map.find(curr_node_name);
  2242. if (node_def_iter == nodedef_map.cend()) {
  2243. REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map, check invalid", curr_node_name.c_str());
  2244. GELOGE(FAILED, "Can't find nodedef");
  2245. return INTERNAL_ERROR;
  2246. }
  2247. domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second;
  2248. GE_CHECK_NOTNULL(curr_node_def);
  2249. bool has_out_retval = false;
  2250. // For the identity operator whose output is "_retval", optimize it
  2251. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map;
  2252. for (auto output_iter = output_map.cbegin(); output_iter != output_map.cend(); ++output_iter) {
  2253. const string &output_node_name = output_iter->first;
  2254. domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
  2255. GE_CHECK_NOTNULL(output_node_def);
  2256. if (output_node_def->op() == "_Retval") {
  2257. GELOGW("_Retval Identity need optimize. node:%s", curr_node_name.c_str());
  2258. output_node_def->set_input(0, curr_node_def->input(0).c_str());
  2259. has_out_retval = true;
  2260. GELOGW("op %s set input(0):%s.", output_node_def->name().c_str(), curr_node_def->input(0).c_str());
  2261. }
  2262. }
  2263. // Deal with non _Retval output operator of Identity.
  2264. if (has_out_retval) {
  2265. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>>::const_iterator output_iter = output_map.begin();
  2266. for (; output_iter != output_map.end(); ++output_iter) {
  2267. const string &output_node_name = output_iter->first;
  2268. GELOGW("[test]node name:%s.", output_node_name.c_str());
  2269. domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
  2270. GE_CHECK_NOTNULL(output_node_def);
  2271. GELOGW("[test]op name:%s, input size:%u.", output_node_def->op().c_str(), output_node_def->input_size());
  2272. GE_IF_BOOL_EXEC(output_node_def->op() == "_Retval", continue);
  2273. for (int k = 0; k < output_node_def->input_size(); ++k) {
  2274. GELOGW("[test]input name:%s, curr_node_name:%s.", output_node_def->input(k).c_str(), curr_node_name.c_str());
  2275. bool is_control = false;
  2276. string node_name;
  2277. GE_RETURN_IF_ERROR(CheckInputNodeName(output_node_def->input(k), &node_name, nullptr, &is_control));
  2278. GE_IF_BOOL_EXEC(
  2279. node_name == curr_node_name, output_node_def->set_input(k, is_control ? ("^" + curr_node_def->input(0)).c_str() : curr_node_def->input(0).c_str());
  2280. GELOGW("%s op set input(%d):%s, is_control:%d.", output_node_def->name().c_str(), k, curr_node_def->input(0).c_str(), is_control);)
  2281. }
  2282. }
  2283. clear_input_flag = true;
  2284. }
  2285. return SUCCESS;
  2286. }
  2287. Status TensorFlowModelParser::GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def,
  2288. map<string, NodeDef *> &nodedef_map,
  2289. const vector<NodeDef *> &nodedef_to_optimize) {
  2290. GE_CHECK_NOTNULL(graph_def);
  2291. if (!nodedef_to_optimize.empty()) {
  2292. // Building input and input relationships for all OP nodes
  2293. GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def));
  2294. } else {
  2295. return SUCCESS;
  2296. }
  2297. for (auto &curr_node_def : nodedef_to_optimize) {
  2298. GE_CHECK_NOTNULL(curr_node_def);
  2299. bool clear_input_flag = false;
  2300. const string &curr_node_name = curr_node_def->name();
  2301. GE_RETURN_IF_ERROR(OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag));
  2302. if (clear_input_flag) {
  2303. GELOGW("[test]node name:%s.", curr_node_name.c_str());
  2304. curr_node_def->clear_input();
  2305. }
  2306. }
  2307. GELOGI("GraphDefOptimizeIdentity success.");
  2308. return SUCCESS;
  2309. }
  2310. Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def,
  2311. map<string, NodeDef *> &nodedef_map,
  2312. const std::pair<string, int> &input_data,
  2313. const std::vector<string> &control_list) {
  2314. GE_CHECK_NOTNULL(curr_mode_def);
  2315. string curr_node_name = curr_mode_def->name();
  2316. auto context_iter = op_node_context_map_.find(curr_node_name);
  2317. if (context_iter == op_node_context_map_.end()) {
  2318. REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str());
  2319. GELOGE(FAILED, "Can't find op node context.");
  2320. return INTERNAL_ERROR;
  2321. }
  2322. OpNodeContext op_node_context = context_iter->second;
  2323. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map;
  2324. for (auto &output_iter : output_map) {
  2325. const string &output_node_name = output_iter.first;
  2326. domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
  2327. GE_CHECK_NOTNULL(output_node_def);
  2328. auto inputs = output_node_def->mutable_input();
  2329. std::vector<std::string> added_inputs;
  2330. for (auto &input : *inputs) {
  2331. string node_name;
  2332. bool is_control = false;
  2333. if (CheckInputNodeName(input, &node_name, nullptr, &is_control) != SUCCESS) {
  2334. GELOGE(FAILED, "parse node input info failed, node %s, input %s.", output_node_def->name().c_str(),
  2335. input.c_str());
  2336. return FAILED;
  2337. }
  2338. if (node_name == curr_node_name) {
  2339. if (is_control) {
  2340. input = "^" + input_data.first;
  2341. } else if (input_data.second == 0) {
  2342. input = input_data.first;
  2343. } else {
  2344. input = input_data.first + ":" + std::to_string(input_data.second);
  2345. }
  2346. GELOGD("Optimize Snapshot node, dest:%s, set input:%s.", output_node_name.c_str(), input.c_str());
  2347. for (auto &item : control_list) {
  2348. bool is_exist_input = false;
  2349. for (auto &tmp_input : output_node_def->input()) {
  2350. string tmp_node_name;
  2351. if (CheckInputNodeName(tmp_input, &tmp_node_name, nullptr, nullptr) != SUCCESS) {
  2352. GELOGE(INTERNAL_ERROR, "parse node input info failed, node %s, input %s.",
  2353. output_node_def->name().c_str(), tmp_input.c_str());
  2354. return FAILED;
  2355. }
  2356. if (tmp_node_name == item) {
  2357. is_exist_input = true;
  2358. break;
  2359. }
  2360. }
  2361. if (!is_exist_input) {
  2362. added_inputs.push_back("^" + item);
  2363. }
  2364. }
  2365. }
  2366. }
  2367. for (std::string added_input : added_inputs) {
  2368. GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), added_input.c_str());
  2369. output_node_def->add_input(added_input);
  2370. }
  2371. }
  2372. // Clear the input of snapshot and become an isolated node
  2373. curr_mode_def->clear_input();
  2374. return SUCCESS;
  2375. }
  2376. Status TensorFlowModelParser::GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def,
  2377. map<string, NodeDef *> &nodedef_map,
  2378. const vector<NodeDef *> &nodedef_to_optimize) {
  2379. GE_CHECK_NOTNULL(graph_def);
  2380. if (!nodedef_to_optimize.empty()) {
  2381. // Building input and input relationships for all OP nodes
  2382. GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def));
  2383. GELOGD("Optimize snapshot num:%zu.", nodedef_to_optimize.size());
  2384. } else {
  2385. return SUCCESS;
  2386. }
  2387. for (auto &curr_node_def : nodedef_to_optimize) {
  2388. GE_CHECK_NOTNULL(curr_node_def);
  2389. std::pair<string, int> input_data; // src node name, src index
  2390. vector<string> control_list;
  2391. uint32_t data_input_cnt = 0;
  2392. for (auto &input : curr_node_def->input()) {
  2393. string node_name;
  2394. int input_index = 0;
  2395. bool is_control = false;
  2396. if (CheckInputNodeName(input, &node_name, &input_index, &is_control) != SUCCESS) {
  2397. GELOGE(FAILED, "parse SnapShot input info failed, node %s, input %s.", curr_node_def->name().c_str(),
  2398. input.c_str());
  2399. return FAILED;
  2400. }
  2401. if (is_control) {
  2402. control_list.push_back(node_name);
  2403. } else {
  2404. data_input_cnt++;
  2405. input_data = std::make_pair(node_name, input_index);
  2406. }
  2407. }
  2408. if (data_input_cnt != 1) {
  2409. REPORT_INNER_ERROR("E19999", "Node:%s's input data size:%u not equal to 1, check invalid",
  2410. curr_node_def->name().c_str(), data_input_cnt);
  2411. GELOGE(FAILED, "%s op data input size %u invalid", curr_node_def->name().c_str(), data_input_cnt);
  2412. return FAILED;
  2413. }
  2414. // Optimize Snapshot Node
  2415. GE_CHK_STATUS_RET(OptimizeSnapShot(curr_node_def, nodedef_map, input_data, control_list));
  2416. }
  2417. GELOGI("GraphDefOptimizeSnapShot success.");
  2418. return SUCCESS;
  2419. }
  2420. Status TensorFlowModelParser::SetDestNodeName(const domi::tensorflow::NodeDef *const node_current,
  2421. domi::tensorflow::NodeDef *const node_dest, const int32_t input_idx,
  2422. const bool is_control, bool &clear_input_flag) {
  2423. GELOGI("current node name is %s ", node_current->name().c_str());
  2424. clear_input_flag = true;
  2425. if (is_control) {
  2426. string node_current_name = node_current->input(0);
  2427. string current_name;
  2428. if (CheckInputNodeName(node_current_name, &current_name, nullptr, nullptr) != SUCCESS) {
  2429. GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", node_current_name.c_str());
  2430. return FAILED;
  2431. }
  2432. current_name = "^" + current_name;
  2433. GELOGI("set nodeCurrentNameTmp: %s", current_name.c_str());
  2434. node_dest->set_input(input_idx, current_name);
  2435. } else {
  2436. node_dest->set_input(input_idx, node_current->input(0).c_str());
  2437. GELOGD("%s op set input:%s.", node_dest->name().c_str(), node_current->input(0).c_str());
  2438. }
  2439. // DestroyTemporaryVariable node have only one input and one output.
  2440. // If the number of inputs is greater than 1, all subsequent inputs are
  2441. // control edge inputs. Therefore, after deleting DestroyTemporaryVariable,
  2442. // these control edge inputs can be directly connected to nodeDst.
  2443. for (int i = 1; i < node_current->input_size(); ++i) {
  2444. node_dest->add_input(node_current->input(i));
  2445. }
  2446. return SUCCESS;
  2447. }
  2448. void TensorFlowModelParser::OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def,
  2449. domi::tensorflow::NodeDef *const nodeCurrent,
  2450. bool &clearInputFlag) const {
  2451. // Internal call to ensure that the parameter is not empty.
  2452. GELOGI("DestroyTemporaryVariable optimizing.");
  2453. for (int w = 0; w < graph_def->node_size(); w++) {
  2454. domi::tensorflow::NodeDef *nodeDst = graph_def->mutable_node(w);
  2455. GE_IF_BOOL_EXEC(nodeDst->name() == nodeCurrent->name(), continue);
  2456. for (int k = 0; k < nodeDst->input_size(); k++) {
  2457. string nodeDstInputName = nodeDst->input(k);
  2458. string nodeDstInputNameTmp;
  2459. bool isControl = false;
  2460. if (CheckInputNodeName(nodeDstInputName, &nodeDstInputNameTmp, nullptr, &isControl) != SUCCESS) {
  2461. GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeDstInputName.c_str());
  2462. return;
  2463. }
  2464. if (nodeDstInputNameTmp != nodeCurrent->name()) {
  2465. continue;
  2466. }
  2467. if (SetDestNodeName(nodeCurrent, nodeDst, k, isControl, clearInputFlag) != SUCCESS) {
  2468. GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeCurrent->name().c_str());
  2469. return;
  2470. }
  2471. GELOGI("Optimize DestroyTemporaryVariable successful.");
  2472. }
  2473. }
  2474. }
  2475. Status TensorFlowModelParser::GraphDefOptimizeDestroyTemporaryVariable(
  2476. domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *const nodeCurrent) const {
  2477. if (graph_def == nullptr || nodeCurrent == nullptr) {
  2478. REPORT_INNER_ERROR("E19999", "Param graph_def or nodeCurrent is nullptr, check invalid");
  2479. GELOGE(FAILED, "input param is nullptr.");
  2480. return FAILED;
  2481. }
  2482. if (nodeCurrent->op() != ge::parser::DESTROYTEMPORARYVARIABLE) {
  2483. return SUCCESS;
  2484. }
  2485. GELOGI("Optimize DestroyTemporaryVariable, node name is :%s.", nodeCurrent->name().c_str());
  2486. bool clearInputFlag = false;
  2487. google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr_map_destroy = nodeCurrent->mutable_attr();
  2488. domi::tensorflow::AttrValue var_name_attr_destroy = (*attr_map_destroy)[ge::VAR_ATTR_NAME];
  2489. for (int j = 0; j < graph_def->node_size(); j++) {
  2490. domi::tensorflow::NodeDef *nodeTmpVar = graph_def->mutable_node(j);
  2491. GE_IF_BOOL_EXEC(nodeTmpVar->op() != ge::parser::TEMPORARYVARIABLE, continue);
  2492. google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr_map_tmp = nodeTmpVar->mutable_attr();
  2493. domi::tensorflow::AttrValue var_name_attr_tmp = (*attr_map_tmp)[ge::VAR_ATTR_NAME];
  2494. if (var_name_attr_destroy.s() != var_name_attr_tmp.s()) {
  2495. continue;
  2496. }
  2497. // Optimize destroytemporaryvariable operator
  2498. OptimizeDestroyTemporaryVariable(graph_def, nodeCurrent, clearInputFlag);
  2499. if (clearInputFlag) {
  2500. nodeCurrent->clear_input(); // Clear the destroytemporaryvariable input to become an isolated node
  2501. break;
  2502. }
  2503. }
  2504. if (!clearInputFlag) {
  2505. REPORT_INNER_ERROR("E19999", "Optimize DestroyTemporaryVariable failed, node name is :%s.",
  2506. nodeCurrent->name().c_str());
  2507. GELOGE(INTERNAL_ERROR, "Optimize DestroyTemporaryVariable failed, node name is :%s.", nodeCurrent->name().c_str());
  2508. return FAILED;
  2509. }
  2510. return SUCCESS;
  2511. }
  2512. struct DelTransposeInfo {
  2513. domi::tensorflow::NodeDef *node_def; // transpose
  2514. domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next]
  2515. int inputIdx;
  2516. };
  2517. Status GetTransposeInfo(domi::tensorflow::GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo,
  2518. std::map<std::string, DelTransposeInfo> &transposeInfo) {
  2519. GE_CHECK_NOTNULL(graph_def);
  2520. for (int i = 0; i < graph_def->node_size(); ++i) {
  2521. auto node_def = graph_def->mutable_node(i);
  2522. if (node_def->op() == ge::parser::TRANSPOSE) {
  2523. DelTransposeInfo transpose;
  2524. transpose.node_def = node_def;
  2525. transposeInfo.insert(std::make_pair(node_def->name(), transpose));
  2526. } else if (node_def->op() == ge::parser::SOFTMAX) {
  2527. softmaxInfo.insert(std::make_pair(node_def->name(), node_def->input(0)));
  2528. GELOGI("softmax name:%s, input name:%s", node_def->name().c_str(), node_def->input(0).c_str());
  2529. }
  2530. }
  2531. for (auto &itTranspose : transposeInfo) {
  2532. for (int j = 0; j < graph_def->node_size(); ++j) {
  2533. auto nextNodeDef = graph_def->mutable_node(j);
  2534. bool bFind = false;
  2535. for (int k = 0; k < nextNodeDef->input_size(); ++k) {
  2536. if (nextNodeDef->input(k) == itTranspose.first) {
  2537. itTranspose.second.nextNodeDef = nextNodeDef;
  2538. itTranspose.second.inputIdx = k;
  2539. GELOGI("transpose info name:%s, next name:%s, idx:%d", itTranspose.second.node_def->name().c_str(),
  2540. nextNodeDef->name().c_str(), k);
  2541. bFind = true;
  2542. break;
  2543. }
  2544. }
  2545. if (bFind) {
  2546. break;
  2547. }
  2548. }
  2549. }
  2550. return SUCCESS;
  2551. }
  2552. Status EraseTransposeNode(std::map<std::string, std::string> &softmaxInfo,
  2553. std::map<std::string, DelTransposeInfo> &transposeInfo) {
  2554. std::map<std::string, DelTransposeInfo>::const_iterator itTranspose = transposeInfo.begin();
  2555. for (; itTranspose != transposeInfo.end();) {
  2556. // transpose --> softmax
  2557. bool bErase = true;
  2558. if (softmaxInfo.find(itTranspose->second.node_def->input(0)) != softmaxInfo.end() ||
  2559. softmaxInfo.find(itTranspose->second.nextNodeDef->name()) != softmaxInfo.end()) {
  2560. bErase = false;
  2561. }
  2562. if (bErase) {
  2563. GELOGI("erase node name:%s, input(0):%s", itTranspose->first.c_str(),
  2564. itTranspose->second.node_def->input(0).c_str());
  2565. itTranspose = transposeInfo.erase(itTranspose);
  2566. } else {
  2567. ++itTranspose;
  2568. }
  2569. }
  2570. if ((softmaxInfo.size() <= SIZE_MAX / kSoftmaxMultiple) &&
  2571. (softmaxInfo.size() * kSoftmaxMultiple != transposeInfo.size())) {
  2572. GELOGW("softmax size[%zu], transpose size[%zu]", softmaxInfo.size(), transposeInfo.size());
  2573. return FAILED;
  2574. }
  2575. return SUCCESS;
  2576. }
  2577. void TensorFlowModelParser::OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo) {
  2578. for (auto &it : transposeInfo) {
  2579. auto transpose = it.second;
  2580. transpose.nextNodeDef->set_input(transpose.inputIdx, transpose.node_def->input(kTransposeInputIdx));
  2581. transpose.node_def->clear_input();
  2582. }
  2583. }
  2584. void TensorFlowModelParser::SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def) {
  2585. // The caller guarantees that the pointer is not null
  2586. for (int i = 0; i < graph_def->node_size(); ++i) {
  2587. auto node_def = graph_def->mutable_node(i);
  2588. if (node_def->op() == ge::parser::SOFTMAX) {
  2589. domi::tensorflow::AttrValue attr_value;
  2590. attr_value.set_i(1);
  2591. ge::TensorFlowUtil::AddNodeAttr("axis", attr_value, node_def);
  2592. GELOGI("SoftmaxAddAttr, name: %s, input name:%s", node_def->name().c_str(), node_def->input(0).c_str());
  2593. }
  2594. }
  2595. }
  2596. Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph_def) {
  2597. GE_CHECK_NOTNULL(graph_def);
  2598. map<string, NodeDef *> nodedef_map;
  2599. vector<string> op_node_name_list;
  2600. // Save Identity and ReadVariableOp
  2601. vector<NodeDef *> identity_to_optimize;
  2602. // Save Snapshot
  2603. vector<NodeDef *> snapshot_to_optimize;
  2604. for (int i = 0; i < graph_def->node_size(); i++) {
  2605. // mutable_node return vale is not empty
  2606. domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i);
  2607. const string &node_name = node_def->name();
  2608. Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
  2609. GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
  2610. if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) {
  2611. identity_to_optimize.push_back(node_def);
  2612. } else if (node_def->op() == ge::parser::SNAPSHOT) {
  2613. snapshot_to_optimize.push_back(node_def);
  2614. }
  2615. nodedef_map[node_name] = node_def;
  2616. }
  2617. // Optimize for Identity/ReadVariableOp
  2618. GE_RETURN_IF_ERROR(GraphDefOptimizeIdentity(graph_def, nodedef_map, identity_to_optimize));
  2619. // Optimize for Snapshot
  2620. GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize));
  2621. for (int i = 0; i < graph_def->node_size(); i++) {
  2622. domi::tensorflow::NodeDef *nodeCurrent = graph_def->mutable_node(i);
  2623. GE_CHK_STATUS_RET(GraphDefOptimizeDestroyTemporaryVariable(graph_def, nodeCurrent));
  2624. }
  2625. // These member variables will be rebuilt later and need to be cleared here.
  2626. nodedef_map_.clear();
  2627. op_node_context_map_.clear();
  2628. return SUCCESS;
  2629. }
// Removes isolated nodes from |graph|: nodes with no inputs and no data
// outputs, and Const/Constant nodes with no data outputs. Data nodes, nodes
// listed as user outputs, and nodes whose name begins with "dpop" are kept.
Status TensorFlowModelParser::RemoveIsolateNode(const ge::ComputeGraphPtr &graph) {
  GE_CHECK_NOTNULL(graph);
  auto nodes = graph->GetDirectNode();
  for (auto &n : nodes) {
    // get front 4 char
    // NOTE(review): "dpop"-prefixed nodes are skipped unconditionally —
    // presumably framework-generated wrapper nodes; confirm before changing.
    if (n->GetName().substr(0, 4) == "dpop") {
      continue;
    }
    // Never delete graph inputs (Data) or nodes requested as outputs.
    if ((n->GetType() == ge::parser::DATA) ||
        (ge::GetParserContext().out_nodes_map.find(n->GetName()) != ge::GetParserContext().out_nodes_map.end())) {
      GELOGI("Can not remove op [%s] because it is data or out node.", n->GetName().c_str());
      continue;
    }
    // Isolate then remove the node when it is fully disconnected, or when it
    // is a constant no data consumer reads (control edges do not keep it).
    GE_IF_BOOL_EXEC((((n->GetInAllNodes().size() == 0) && (n->GetOutDataNodes().size() == 0)) ||
                     ((n->GetType() == ge::parser::CONSTANTOP || n->GetType() == ge::parser::CONSTANT) &&
                      (n->GetOutDataNodes().size() == 0))),
                    GE_CHK_STATUS_RET(ge::GraphUtils::IsolateNode(n, {}), "Isolate removed node: %s, type: %s failed",
                                      n->GetName().c_str(), n->GetType().c_str());
                    GE_CHK_STATUS_RET(ge::GraphUtils::RemoveNodeWithoutRelink(graph, n),
                                      "Remove node: %s, type: %s without relink failed", n->GetName().c_str(),
                                      n->GetType().c_str()););
  }
  return SUCCESS;
}
// The format specified by the command line argument is preferred,
// if not specified, use InferInputFormats to infer,
// and if the inference fails, the default NHWC format is used.
domiTensorFormat_t TensorFlowModelParser::InferInputFormats() {
  // The command-line-provided format wins outright.
  GE_IF_BOOL_EXEC(ge::GetParserContext().format != DOMI_TENSOR_RESERVED, return ge::GetParserContext().format);
  domiTensorFormat_t global_input_format = DOMI_TENSOR_RESERVED;
  set<const NodeDef *> visited_node;
  for (auto &node_item : nodedef_map_) {
    // Infer format for data node and save it to ge::GetParserContext().format.
    domiTensorFormat_t format = DOMI_TENSOR_RESERVED;
    const NodeDef *node = node_item.second;
    if (node == nullptr) {
      // NOTE(review): returns DOMI_TENSOR_RESERVED here instead of the NHWC
      // fallback; callers must tolerate a reserved format on this path.
      return format;
    }
    auto it = tensorflow_op_map.find(node->op());
    if (it != tensorflow_op_map.end() && it->second == ge::parser::DATA) {
      // Walk forward from each Data node; any inference failure falls back
      // to NHWC.
      GE_IF_BOOL_EXEC(GetNodeFormat(node, NO_TRANSPOSE, format, visited_node) != SUCCESS,
                      GELOGW("Cannot infer input format, the NHWC format is used by default, and you can also "
                             "specify format by command line arguments.");
                      return domi::DOMI_TENSOR_NHWC);
      GE_IF_BOOL_EXEC(global_input_format == DOMI_TENSOR_RESERVED, global_input_format = format);
      // Conflicting formats across multiple Data nodes also fall back to NHWC.
      GE_IF_BOOL_EXEC(
          format != DOMI_TENSOR_RESERVED && format != global_input_format,
          GELOGW("Multiple data ops with different formats are not supported, "
                 "the NHWC format is used by default, and you can also specify format by command line arguments.");
          return domi::DOMI_TENSOR_NHWC);
    }
  }
  // Nothing pinned a format down: default to NHWC.
  return global_input_format == DOMI_TENSOR_RESERVED ? domi::DOMI_TENSOR_NHWC : global_input_format;
}
// Recursively infers the tensor format at |node| by walking its consumers.
// |pred_transpose| carries the direction of a transpose crossed on the way
// here (NO_TRANSPOSE when none); |visited_node| cuts cycles and repeats.
// On success, |format| holds the inferred format, or DOMI_TENSOR_RESERVED
// when nothing downstream determined one. Returns FAILED on conflicts.
Status TensorFlowModelParser::GetNodeFormat(const NodeDef *node, TfTranspose pred_transpose, domiTensorFormat_t &format,
                                            set<const NodeDef *> &visited_node) {
  GE_CHECK_NOTNULL(node);
  // Avoid repeated visits.
  GE_IF_BOOL_EXEC(visited_node.find(node) != visited_node.end(), return SUCCESS);
  visited_node.emplace(node);
  // Switch/Merge are passed over without contributing format information.
  GE_IF_BOOL_EXEC(node->op() == TENSORFLOWF_NODE_OP_SWITCH || node->op() == TENSORFLOWF_NODE_OP_MERGE, return SUCCESS);
  // If node has a data_format attribute, format is set according to data_format.
  domi::tensorflow::AttrValue attr;
  if (ge::TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_DATA_FORMAT, attr) && node->op() != ge::parser::BIASADD) {
    GE_RETURN_IF_ERROR(ge::TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING));
    format = (attr.s() == TENSORFLOWF_TENSOR_NCHW) ? domi::DOMI_TENSOR_NCHW : domi::DOMI_TENSOR_NHWC;
    // A transpose crossed on the way here flips the effective input format...
    GE_IF_BOOL_EXEC(format == domi::DOMI_TENSOR_NCHW && pred_transpose == TO_NCHW, format = domi::DOMI_TENSOR_NHWC);
    GE_IF_BOOL_EXEC(format == domi::DOMI_TENSOR_NHWC && pred_transpose == TO_NHWC, format = domi::DOMI_TENSOR_NCHW);
    // ...while a transpose toward the attribute's own format is contradictory.
    GE_IF_BOOL_EXEC((format == domi::DOMI_TENSOR_NCHW && pred_transpose == TO_NHWC) ||
                    (format == domi::DOMI_TENSOR_NHWC && pred_transpose == TO_NCHW),
                    GELOGI("Format conflicts with transpose.");
                    return FAILED);
    return SUCCESS;
  }
  TfTranspose transpose;
  GE_RETURN_IF_ERROR(GetFormatTranspose(node, transpose));
  // A second transpose in the same direction is treated as a conflict.
  GE_IF_BOOL_EXEC(pred_transpose == transpose && pred_transpose != NO_TRANSPOSE,
                  GELOGI("Multiple transpose conflicts.");
                  return FAILED);
  // If node does not have the data_format attribute, format is set according to the output node.
  string node_name = node->name();
  GE_IF_BOOL_EXEC(op_node_context_map_.find(node_name) == op_node_context_map_.end(),
                  GELOGI("node %s not found in op_node_context_map_", node_name.c_str());
                  return FAILED);
  domiTensorFormat_t inferred_format = DOMI_TENSOR_RESERVED;
  const OpNodeContext &node_ctx = op_node_context_map_.at(node_name);
  for (const auto &output_item : node_ctx.output_map) {
    auto node_iter = nodedef_map_.find(output_item.first);
    GE_IF_BOOL_EXEC(node_iter == nodedef_map_.end(),
                    GELOGI("node %s not found in nodedef_map_", output_item.first.c_str());
                    return FAILED);
    const NodeDef *output_node = node_iter->second;
    GE_CHECK_NOTNULL(output_node);
    domiTensorFormat_t output_format = DOMI_TENSOR_RESERVED;
    GE_RETURN_IF_ERROR(GetNodeFormat(output_node, transpose, output_format, visited_node));
    // All consumers that report a concrete format must agree with each other.
    GE_IF_BOOL_EXEC(output_format != DOMI_TENSOR_RESERVED && inferred_format != DOMI_TENSOR_RESERVED &&
                    output_format != inferred_format,
                    GELOGI("Multiple output formats conflict.");
                    return FAILED);
    inferred_format = output_format;
  }
  format = inferred_format;
  return SUCCESS;
}
  2734. Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc) const {
  2735. GE_CHECK_NOTNULL(transpose_node);
  2736. transpose_direc = NO_TRANSPOSE;
  2737. GE_IF_BOOL_EXEC(transpose_node->op() != TENSORFLOWF_NODE_OP_TRANSPOSE, return SUCCESS);
  2738. GE_IF_BOOL_EXEC(transpose_node->input_size() != kInputNumInt, GELOGI("Input size of transpose is not 2.");
  2739. return FAILED);
  2740. string perm_node_name = transpose_node->input(1);
  2741. auto it = nodedef_map_.find(perm_node_name);
  2742. GE_IF_BOOL_EXEC(it == nodedef_map_.end(), GELOGI("Node %s not found in nodedef_map_.", perm_node_name.c_str());
  2743. return FAILED);
  2744. const NodeDef *perm_node = it->second;
  2745. GE_CHECK_NOTNULL(perm_node);
  2746. domi::tensorflow::AttrValue attr_value;
  2747. GE_IF_BOOL_EXEC(perm_node->op() != TENSORFLOWF_NODE_OP_CONST, GELOGI("Input node of transpose is not const.");
  2748. return FAILED);
  2749. GE_IF_BOOL_EXEC(!ge::TensorFlowUtil::FindAttrValue(perm_node, TENSORFLOW_ATTR_DTYPE, attr_value), return FAILED);
  2750. GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TYPE) != SUCCESS,
  2751. return FAILED);
  2752. domi::tensorflow::DataType type = attr_value.type();
  2753. GE_IF_BOOL_EXEC(type != domi::tensorflow::DT_INT32 && type != domi::tensorflow::DT_INT64, return FAILED);
  2754. GE_IF_BOOL_EXEC(!ge::TensorFlowUtil::FindAttrValue(perm_node, TENSORFLOW_ATTR_VALUE, attr_value), return FAILED);
  2755. GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS,
  2756. return FAILED);
  2757. const TensorProto &tensor = attr_value.tensor();
  2758. const domi::tensorflow::TensorShapeProto &tensor_shape = tensor.tensor_shape();
  2759. GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != parser::DIM_DEFAULT_SIZE,
  2760. return SUCCESS);
  2761. GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS);
  2762. vector<int64_t> perm_value;
  2763. GE_IF_BOOL_EXEC(
  2764. type == domi::tensorflow::DT_INT32,
  2765. const int32_t *data = reinterpret_cast<const int32_t *>(tensor.tensor_content().data());
  2766. for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) {
  2767. perm_value.push_back(data[i]);
  2768. });
  2769. GE_IF_BOOL_EXEC(
  2770. type == domi::tensorflow::DT_INT64,
  2771. const int64_t *data = reinterpret_cast<const int64_t *>(tensor.tensor_content().data());
  2772. for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) {
  2773. perm_value.push_back(data[i]);
  2774. });
  2775. // 0, 1, 2, 3 present dim num.
  2776. vector<int64_t> perm_to_nchw = {0, 3, 1, 2};
  2777. vector<int64_t> perm_to_nhwc = {0, 2, 3, 1};
  2778. GE_IF_BOOL_EXEC(perm_value == perm_to_nchw, transpose_direc = TO_NCHW);
  2779. GE_IF_BOOL_EXEC(perm_value == perm_to_nhwc, transpose_direc = TO_NHWC);
  2780. return SUCCESS;
  2781. }
  2782. Status TensorFlowModelParser::TrimGraph(const domi::tensorflow::GraphDef &input_graph_def,
  2783. domi::tensorflow::GraphDef *output_graph_def) {
  2784. GE_CHECK_NOTNULL(output_graph_def);
  2785. if (!ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) {
  2786. return TrimGraphByInput(input_graph_def, output_graph_def);
  2787. } else {
  2788. return TrimGraphByOutput(input_graph_def, output_graph_def);
  2789. }
  2790. }
Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def,
                                               domi::tensorflow::GraphDef *const output_graph_def) {
  // Trims the graph for the case where only --input_shape is specified:
  // every ancestor of a designated input node is dropped, and each designated
  // input node is rewritten into a Placeholder carrying the designated shape.
  // The caller guarantees that the pointer is not null.
  std::set<string> delete_nodes;  // designated inputs plus all of their ancestors
  std::set<string> input_nodes;   // user-designated graph inputs
  for (auto &iter : ge::GetParserContext().input_dims) {
    input_nodes.insert(iter.first);
  }
  std::map<string, const NodeDef *> node_lookup;
  for (const NodeDef &node : input_graph_def.node()) {
    node_lookup[node.name()] = &node;
  }
  // Backward BFS frontier, seeded with the designated inputs.
  std::vector<string> current_inputs;
  for (auto &iter : ge::GetParserContext().input_dims) {
    current_inputs.push_back(iter.first);
  }
  while (!current_inputs.empty()) {
    std::set<string> next_inputs;
    for (const string &current_input : current_inputs) {
      delete_nodes.insert(current_input);
      GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U,
                       ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
                                                                      {"input_shape", current_input});
                       return FAILED, "Input op[%s] not found in graph.", current_input.c_str());
      const NodeDef *current_node = node_lookup[current_input];
      GE_CHECK_NOTNULL(current_node);
      for (const string &input_name : current_node->input()) {
        string input_node_name = NodeNameFromInput(input_name);
        if (delete_nodes.count(input_node_name) == 0U) {
          next_inputs.insert(input_node_name);
        }
      }
    }
    current_inputs = std::vector<string>(next_inputs.begin(), next_inputs.end());
  }
  // Keep the designated inputs themselves plus every node that is not an
  // ancestor of an input. Designated inputs are members of delete_nodes, so
  // each surviving node is added by exactly one of the two branches.
  domi::tensorflow::GraphDef filtered_graph_def;
  filtered_graph_def.mutable_node()->Clear();
  for (const NodeDef &node : input_graph_def.node()) {
    if (static_cast<bool>(input_nodes.count(node.name()))) {
      *(filtered_graph_def.mutable_node()->Add()) = node;
    }
    if (delete_nodes.count(node.name()) == 0U) {
      *(filtered_graph_def.mutable_node()->Add()) = node;
    }
  }
  output_graph_def->Clear();
  for (const NodeDef &node : filtered_graph_def.node()) {
    if (static_cast<bool>(input_nodes.count(node.name()))) {
      // Replace the designated input by a Placeholder with the designated
      // shape and no incoming edges.
      NodeDef placeholder_node = node;
      placeholder_node.clear_input();
      GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder"));
      domi::tensorflow::AttrValue attr_value;
      domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape();
      GE_CHECK_NOTNULL(data_shape);
      const ge::ParserContext &ctx = ge::GetParserContext();
      std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
      // at() is safe here: membership in input_nodes implies a key in input_dims.
      std::vector<int64_t> designated_dims = input_dims.at(node.name());
      for (int32_t i = 0; i < static_cast<int32_t>(designated_dims.size()); i++) {
        data_shape->add_dim()->set_size(designated_dims[i]);
      }
      google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr = placeholder_node.mutable_attr();
      (*attr)[TENSORFLOW_ATTR_SHAPE] = attr_value;
      GE_CHECK_NOTNULL(output_graph_def->mutable_node());
      *(output_graph_def->mutable_node()->Add()) = placeholder_node;
    } else {
      GE_CHECK_NOTNULL(output_graph_def->mutable_node());
      *(output_graph_def->mutable_node()->Add()) = node;
    }
  }
  return SUCCESS;
}
Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def,
                                                domi::tensorflow::GraphDef *const output_graph_def) {
  // Trims the graph when output nodes (--out_nodes) are specified: keeps only
  // nodes reachable backwards from the requested outputs, stopping at the
  // designated inputs, which are rewritten into Placeholders carrying the
  // designated shapes. The caller guarantees that the pointer is not null.
  std::set<string> required_nodes;  // nodes to keep in the trimmed graph
  std::set<string> input_nodes;     // user-designated graph inputs
  for (auto &iter : ge::GetParserContext().input_dims) {
    required_nodes.insert(iter.first);
    input_nodes.insert(iter.first);
  }
  for (auto &iter : ge::GetParserContext().out_nodes_map) {
    required_nodes.insert(iter.first);
  }
  std::map<string, const NodeDef *> node_lookup;
  for (const NodeDef &node : input_graph_def.node()) {
    node_lookup[node.name()] = &node;
  }
  // Backward BFS frontier, seeded with the requested outputs.
  std::vector<string> current_inputs;
  for (auto &iter : ge::GetParserContext().out_nodes_map) {
    current_inputs.push_back(iter.first);
  }
  while (!current_inputs.empty()) {
    std::set<string> next_inputs;
    for (const string &current_input : current_inputs) {
      required_nodes.insert(current_input);
      // Traversal stops at designated inputs; their ancestors stay unrequired.
      GE_IF_BOOL_EXEC(static_cast<bool>(input_nodes.count(current_input)), continue);
      GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U,
                       ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
                                                                      {"out_nodes", current_input});
                       return FAILED, "op[%s] not found in graph.", current_input.c_str());
      const NodeDef *current_node = node_lookup[current_input];
      GE_CHECK_NOTNULL(current_node);
      for (const string &input_name : current_node->input()) {
        string input_node_name = NodeNameFromInput(input_name);
        if (required_nodes.count(input_node_name) == 0U) {
          next_inputs.insert(input_node_name);
        }
      }
    }
    current_inputs = std::vector<string>(next_inputs.begin(), next_inputs.end());
  }
  // Keep only the required nodes, preserving their original order.
  domi::tensorflow::GraphDef filtered_graph_def;
  filtered_graph_def.mutable_node()->Clear();
  for (const NodeDef &node : input_graph_def.node()) {
    if (static_cast<bool>(required_nodes.count(node.name()))) {
      *(filtered_graph_def.mutable_node()->Add()) = node;
    }
  }
  output_graph_def->Clear();
  for (const NodeDef &node : filtered_graph_def.node()) {
    if (static_cast<bool>(input_nodes.count(node.name()))) {
      // Replace each designated input by a Placeholder with the designated
      // shape and no incoming edges.
      NodeDef placeholder_node = node;
      placeholder_node.clear_input();
      GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder"));
      domi::tensorflow::AttrValue attr_value;
      domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape();
      GE_CHECK_NOTNULL(data_shape);
      const ge::ParserContext &ctx = ge::GetParserContext();
      std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
      // at() is safe here: membership in input_nodes implies a key in input_dims.
      std::vector<int64_t> designated_dims = input_dims.at(node.name());
      for (int32_t i = 0; i < static_cast<int32_t>(designated_dims.size()); i++) {
        data_shape->add_dim()->set_size(designated_dims[i]);
      }
      google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr = placeholder_node.mutable_attr();
      (*attr)[TENSORFLOW_ATTR_SHAPE] = attr_value;
      GE_CHECK_NOTNULL(output_graph_def->mutable_node());
      *(output_graph_def->mutable_node()->Add()) = placeholder_node;
    } else {
      GE_CHECK_NOTNULL(output_graph_def->mutable_node());
      *(output_graph_def->mutable_node()->Add()) = node;
    }
  }
  return SUCCESS;
}
  2935. string TensorFlowModelParser::NodeNameFromInput(const string &input_name) {
  2936. string prefix;
  2937. string node_name;
  2938. string suffix;
  2939. std::vector<string> input_parts = ge::StringUtils::Split(input_name, ':');
  2940. suffix = (input_parts.size() < kInputNumUint) ? "" : (":" + input_parts[1]);
  2941. string tmp_name = input_parts[0];
  2942. GE_IF_BOOL_EXEC(input_parts[0].find("^") == 0, tmp_name = tmp_name.substr(1, tmp_name.length() - 1));
  2943. node_name = tmp_name;
  2944. return node_name;
  2945. }
  2946. Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_parser,
  2947. const domi::tensorflow::NodeDef *node_def,
  2948. ge::NodePtr &node) const {
  2949. GE_CHECK_NOTNULL(node_def);
  2950. GE_CHECK_NOTNULL(node);
  2951. GE_CHECK_NOTNULL(op_parser);
  2952. GELOGI("FusionNodeParseParams:node name:%s.", node_def->name().c_str());
  2953. // The fusion operator deals with parseparams separately
  2954. shared_ptr<TensorFlowFusionOpParser> tensorflow_fusion_op_parser =
  2955. std::dynamic_pointer_cast<TensorFlowFusionOpParser>(op_parser);
  2956. GE_IF_BOOL_EXEC(tensorflow_fusion_op_parser == nullptr,
  2957. REPORT_INNER_ERROR("E19999", "Param op_parser is not TensorFlowFusionOpParser Type, check invalid");
  2958. GELOGE(FAILED, "node :%s can not get fusion parser, please check!", node_def->name().c_str());
  2959. return INTERNAL_ERROR);
  2960. // Find all children of the fusion operator
  2961. auto iter = fusion_op_nodedef_map_.find(node_def->name());
  2962. if (iter == fusion_op_nodedef_map_.end()) {
  2963. REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid",
  2964. node_def->name().c_str());
  2965. GELOGE(FAILED, "FusionOp node %s has no children node!", node_def->name().c_str());
  2966. return INTERNAL_ERROR;
  2967. }
  2968. (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op());
  2969. vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second;
  2970. domi::FusionParseParamByOpFunc parse_param_func =
  2971. domi::OpRegistry::Instance()->GetFusionParseParamByOpFunc(node->GetType(), node_def->op());
  2972. Status status = FAILED;
  2973. if (parse_param_func == nullptr) {
  2974. status = tensorflow_fusion_op_parser->ParseParams(node_def_v, node);
  2975. GE_CHK_STATUS_EXEC(status, return status, "Parse Params for fusionop node %s failed", node_def->name().c_str());
  2976. } else {
  2977. vector<ge::Operator> op_src_vec;
  2978. for (const auto &node_def_src : node_def_v) {
  2979. ge::Operator op_src(node_def_src->name().c_str(), node_def_src->op().c_str());
  2980. status = domi::OperatorAutoMapping(node_def_src, op_src);
  2981. if (status != SUCCESS) {
  2982. REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", node_def_src->name().c_str(),
  2983. node_def_src->op().c_str());
  2984. GELOGE(status, "Node[%s] auto mapping failed", node_def_src->name().c_str());
  2985. return status;
  2986. }
  2987. auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_src);
  2988. GE_CHECK_NOTNULL(op_desc);
  2989. for (int32_t i = 0; i < node_def_src->input_size(); i++) {
  2990. ge::GeTensorDesc tensor_desc;
  2991. tensor_desc.SetName(node_def_src->input(i));
  2992. if (op_desc->AddInputDesc(tensor_desc) != GRAPH_SUCCESS) {
  2993. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", op_desc->GetName().c_str(),
  2994. op_desc->GetType().c_str());
  2995. GELOGE(FAILED, "Op [%s] type[%s] add input(%d) tensor failed.", op_desc->GetName().c_str(),
  2996. op_desc->GetType().c_str(), i);
  2997. return FAILED;
  2998. }
  2999. }
  3000. op_src_vec.push_back(op_src);
  3001. }
  3002. shared_ptr<TensorFlowFusionCustomParserAdapter> tf_custom_fusion_op_paser =
  3003. std::dynamic_pointer_cast<TensorFlowFusionCustomParserAdapter>(tensorflow_fusion_op_parser);
  3004. status = tf_custom_fusion_op_paser->ParseParams(op_src_vec, node);
  3005. if (status != SUCCESS) {
  3006. GELOGE(status, "Parse params for fusionop node %s failed", node_def->name().c_str());
  3007. return status;
  3008. }
  3009. }
  3010. return SUCCESS;
  3011. }
/**
 * @ingroup domi_omg
 * @brief Optimizing const nodes for custom operators: for every registered
 *        TVM custom op, remove/reorder the configured inputs; removed const
 *        inputs are turned into control edges by RemoveInputs.
 * @param [in] graph_def graph object
 * @return SUCCESS optimize successfully
 * @return other   optimize failed
 *
 */
Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const {
  GE_CHECK_NOTNULL(graph_def);
  // 1. find all the nodes in the graph and save them to all_nodedef_map
  map<string, NodeDef *> all_nodedef_map;
  int graph_node_size = graph_def->node_size();
  for (int i = 0; i != graph_node_size; ++i) {
    // mutable_node return vale is not empty
    domi::tensorflow::NodeDef *current_node = graph_def->mutable_node(i);
    string node_name = current_node->name();
    all_nodedef_map[node_name] = current_node;
  }
  GELOGD("node size is: %zu", all_nodedef_map.size());
  // 2. move input to attr.
  for (auto &it_node_map : all_nodedef_map) {
    domi::tensorflow::NodeDef *current_node = it_node_map.second;
    GE_CHECK_NOTNULL(current_node);
    string current_op_name = current_node->op();
    // 2.1. check whether the current op is register for move to attr.
    const std::vector<domi::RemoveInputConfigure> &move_input_vec =
        domi::OpRegistry::Instance()->GetRemoveInputConfigure(current_op_name);
    // 2.2 check whether the current op is a TVM op.
    const bool is_unknown_custom_op = move_input_vec.empty() ||
        (domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) != domi::ImplyType::TVM);
    if (is_unknown_custom_op) {
      // Not registered for input removal, or not a TVM custom op: skip.
      GELOGI("op %s is not TVM op, move input size: %zu", current_op_name.c_str(), move_input_vec.size());
      continue;
    }
    GELOGD("Current op %s is registered for remove input and tvm op", current_op_name.c_str());
    // 2.3 copy input to attr
    set<uint32_t> unused_inputs;  // original input indexes scheduled for removal
    for (const auto &it : move_input_vec) {
      uint32_t move_index;
      if (it.inputIdx >= 0) {
        move_index = it.inputIdx;
      } else {
        // A negative inputIdx counts from the end of the input list; it must
        // not reach before the first input.
        GE_IF_BOOL_EXEC(
            -it.inputIdx > current_node->input_size(),
            ErrorManager::GetInstance().ATCReportErrMessage(
                "E12004", {"opname", "inputIdx", "inputsize"},
                {current_op_name, std::to_string(-it.inputIdx), std::to_string(current_node->input_size())});
            GELOGE(INTERNAL_ERROR,
                   "Op[%s] register failed, inputIdx[-%d] should be greater than inputsize[%d] when inputIdx < 0.",
                   current_op_name.c_str(), it.inputIdx, current_node->input_size());
            return PARAM_INVALID);
        move_index = current_node->input_size() + it.inputIdx;
      }
      // For an isolated node in deep lab V3 networ.
      // solve the problem of protobuf index less current_size.
      GE_IF_BOOL_EXEC(current_node->input_size() == 0, GELOGI("Input size is 0, already optimized"); continue);
      if (it.moveType == domi::RemoveInputType::OMG_REMOVE_TYPE_WITH_COND) {
        // Conditional removal: drop the input only when the node's bool attr
        // it.attrName equals the registered it.attrValue.
        domi::tensorflow::AttrValue attr_value;
        GE_IF_BOOL_EXEC(!(ge::TensorFlowUtil::FindAttrValue(current_node, it.attrName, attr_value)),
                        REPORT_INNER_ERROR("E19999", "Op:%s register AttrName[%s] has no value, check invalid",
                                           current_op_name.c_str(), it.attrName.c_str());
                        GELOGE(INTERNAL_ERROR, "AttrName[%s] has no value!", it.attrName.c_str());
                        return PARAM_INVALID);
        GE_IF_BOOL_EXEC(attr_value.b() == it.attrValue, unused_inputs.insert(move_index));
      } else if (it.moveType == domi::RemoveInputType::OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE &&
                 it.originalType == current_op_name) {
        // Unconditional removal for a matching original op type.
        GELOGD("Input %s:%d will be removed.", current_op_name.c_str(), move_index);
        unused_inputs.insert(move_index);
      } else if (it.moveType == domi::RemoveInputType::OMG_INPUT_REORDER) {
        // Rearrange the node's inputs into the registered order; the order
        // vector must cover every input exactly.
        auto inputs = current_node->input();
        if (static_cast<size_t>(inputs.size()) != it.input_order.size()) {
          REPORT_INNER_ERROR("E19999", "Input size of node:%s(%s) is mismatched, new order size:%zu, input size:%d",
                             current_node->name().c_str(), current_node->op().c_str(), it.input_order.size(),
                             inputs.size());
          GELOGE(INTERNAL_ERROR, "Size of input is mismatched, new order size is %zu, input size is %d.",
                 it.input_order.size(), inputs.size());
          return INTERNAL_ERROR;
        }
        for (size_t i = 0; i < it.input_order.size(); ++i) {
          int new_index = it.input_order[i];
          const bool is_input_invalid = (new_index < 0) || (new_index >= inputs.size());
          if (is_input_invalid) {
            REPORT_INNER_ERROR("E19999", "New order of %s has invalid index %d, out of range(0, %d)",
                               it_node_map.first.c_str(), new_index, inputs.size());
            GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", it_node_map.first.c_str(), new_index);
            return INTERNAL_ERROR;
          }
          // `inputs` is a copy, so writing through set_input is safe here.
          current_node->set_input(i, inputs[new_index]);
        }
        GELOGI("The input sequence of the node has been rearranged, node name:%s.", it_node_map.first.c_str());
      }
    }
    // 2.4 remove the input const nodes
    Status ret = RemoveInputs(graph_def, current_node, unused_inputs, all_nodedef_map);
    if (ret != SUCCESS) {
      REPORT_CALL_ERROR("E19999", "remove input for op:%s failed", current_op_name.c_str());
      GELOGE(INTERNAL_ERROR, "Op[%s] remove input failed.", current_op_name.c_str());
      return ret;
    }
  }
  return SUCCESS;
}
  3115. Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def,
  3116. domi::tensorflow::NodeDef *node_def,
  3117. const map<string, NodeDef *> &all_node_map,
  3118. const vector<string> &removed_inputs_vec) const {
  3119. GE_CHECK_NOTNULL(graph_def);
  3120. GE_CHECK_NOTNULL(node_def);
  3121. for (const auto &remove_input : removed_inputs_vec) {
  3122. string input_node_name = NodeNameFromInput(remove_input);
  3123. auto it = all_node_map.find(input_node_name);
  3124. if (it == all_node_map.end()) {
  3125. REPORT_INNER_ERROR("E19999", "Node:%s can't find in all_node_map, check invalid", input_node_name.c_str());
  3126. GELOGE(FAILED, "Can not find node name:%s in all node map.", input_node_name.c_str());
  3127. return FAILED;
  3128. }
  3129. NodeDef *input_node_def = it->second;
  3130. if ((input_node_def->op() == parser::SWITCH) || (input_node_def->op() == parser::REFSWITCH)) {
  3131. NodeDef *identity_node_def = graph_def->add_node();
  3132. GE_CHECK_NOTNULL(identity_node_def);
  3133. std::string remove_input_name = remove_input;
  3134. remove_input_name = remove_input_name.find(":") == std::string::npos ?
  3135. input_node_name : (remove_input_name.replace(remove_input_name.find(":"), 1, "_"));
  3136. input_node_name = remove_input_name + "_identity";
  3137. identity_node_def->set_name(input_node_name);
  3138. identity_node_def->set_op(parser::IDENTITY);
  3139. identity_node_def->add_input(remove_input);
  3140. }
  3141. string control_input = "^" + input_node_name;
  3142. node_def->add_input(control_input);
  3143. GELOGD("Add control input:%s for node:%s", control_input.c_str(), node_def->name().c_str());
  3144. }
  3145. return SUCCESS;
  3146. }
/**
 * @ingroup domi_omg
 * @brief Delete input from nodedef, keeping execution order via control edges
 * @param [in] graph_def Graph object (needed to insert Identity nodes)
 * @param [in] node_def Nodedef object
 * @param [in] remove_index_set Index collection of input nodes to be deleted
 * @param [in] all_node_map All nodes of the graph, keyed by node name
 * @return SUCCESS remove successfully
 * @return FAILED remove failed
 *
 */
Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *node_def,
                                           const set<uint32_t> &remove_index_set,
                                           const map<string, NodeDef *> &all_node_map) const {
  GE_CHECK_NOTNULL(node_def);
  if (remove_index_set.empty()) {
    GELOGI("The size of remove_index_set is zero.");
    return SUCCESS;
  }
  // Group the indexes to remove by input name: the same producer may feed
  // several slots, so removal matches on (name, original index) pairs.
  map<string, vector<int>> remove_inputs_map;
  for (auto &it : remove_index_set) {
    const string &input_node_name = node_def->input(it);
    remove_inputs_map[input_node_name].emplace_back(it);
    GELOGD("Push input:%s, index:%d into remove map.", input_node_name.c_str(), it);
  }
  // Keep the serialized input tensor descriptions aligned before mutating the
  // input list itself.
  RemoveInputAttr(node_def, remove_inputs_map);
  int index = 0;
  vector<string> removed_inputs_vec;
  auto *inputs = node_def->mutable_input();
  // NOTE: index advances on every iteration, including after an erase, so it
  // keeps tracking positions in the ORIGINAL input list — which is what
  // remove_inputs_map refers to.
  for (auto input_it = inputs->begin(); input_it != inputs->end(); ++index) {
    // 1.decide whether to remove the input
    bool flag = false;
    for (auto &remove_input : remove_inputs_map) {
      string remove_input_name = remove_input.first;
      vector<int> remove_input_indexs = remove_input.second;
      if (((*input_it) == remove_input_name) &&
          (std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end())) {
        GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index);
        flag = true;
        removed_inputs_vec.emplace_back(remove_input_name);
        break;
      }
    }
    if (flag) {
      // 2 remove the input
      input_it = inputs->erase(input_it);
    } else {
      ++input_it;
    }
  }
  // Preserve execution ordering for the removed producers via control edges.
  Status ret = AddControlEdgeAfterRemoveInputs(graph_def, node_def, all_node_map, removed_inputs_vec);
  if (ret != SUCCESS) {
    GELOGE(FAILED, "Add control edges for node:%s failed.", node_def->name().c_str());
    return FAILED;
  }
  return SUCCESS;
}
  3202. void TensorFlowModelParser::RemoveInputAttr(domi::tensorflow::NodeDef *node_def,
  3203. const map<string, vector<int>> &remove_inputs_map) const {
  3204. // The caller guarantees that the pointer is not null
  3205. auto *inputs = node_def->mutable_input();
  3206. google::protobuf::Map<std::string, domi::tensorflow::AttrValue> *attr_map = node_def->mutable_attr();
  3207. const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::iterator it =
  3208. attr_map->find(ge::ATTR_NAME_INPUT_TENSOR_DESC);
  3209. if (it == attr_map->end()) {
  3210. GELOGW("Failed to find input desc from tf node_def[%s]", node_def->name().c_str());
  3211. return;
  3212. }
  3213. domi::tensorflow::AttrValue *input_attr_value = &(it->second);
  3214. auto tmp_attr = input_attr_value->mutable_list()->mutable_func();
  3215. auto attr_it = tmp_attr->begin();
  3216. int index = 0;
  3217. for (auto input_it = inputs->begin(); input_it != inputs->end(); ++input_it, ++index) {
  3218. // 1.decide whether to remove the input
  3219. bool flag = false;
  3220. for (auto &remove_input : remove_inputs_map) {
  3221. string remove_input_name = remove_input.first;
  3222. vector<int> remove_input_indexs = remove_input.second;
  3223. if ((*input_it) == remove_input_name &&
  3224. std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) {
  3225. GELOGD("Remove input attr:%s, index:%d", remove_input_name.c_str(), index);
  3226. flag = true;
  3227. break;
  3228. }
  3229. }
  3230. if (flag) {
  3231. // 2.1 remove the input attr
  3232. if (!tmp_attr->empty() && (attr_it != tmp_attr->end())) {
  3233. attr_it = tmp_attr->erase(attr_it);
  3234. } else {
  3235. ++attr_it;
  3236. }
  3237. } else {
  3238. ++attr_it;
  3239. }
  3240. }
  3241. }
  3242. Status TensorFlowModelParser::GetTensorflowGraphInOutMap(domi::tensorflow::GraphDef *graph_def) {
  3243. GE_CHECK_NOTNULL(graph_def);
  3244. for (int i = 0; i < graph_def->node_size(); i++) {
  3245. domi::tensorflow::NodeDef *node = graph_def->mutable_node(i);
  3246. const string &node_name = node->name();
  3247. node_inputs_outputs_map_.emplace(node_name, std::pair<set<string>, set<string>>{});
  3248. for (const auto &input : node->input()) {
  3249. string input_node_name;
  3250. GE_RETURN_IF_ERROR(CheckInputNodeName(input, &input_node_name, nullptr, nullptr));
  3251. node_inputs_outputs_map_[node_name].first.insert(input_node_name);
  3252. node_inputs_outputs_map_[input_node_name].second.insert(node_name);
  3253. }
  3254. }
  3255. return SUCCESS;
  3256. }
  3257. Status TensorFlowModelParser::RemoveIsolateNode(domi::tensorflow::GraphDef *graph_def) {
  3258. GE_CHECK_NOTNULL(graph_def);
  3259. set<string> node_to_delete;
  3260. for (int i = 0; i < graph_def->node_size(); i++) {
  3261. domi::tensorflow::NodeDef *node = graph_def->mutable_node(i);
  3262. const string &node_name = node->name();
  3263. if (node_inputs_outputs_map_.find(node_name) == node_inputs_outputs_map_.end()) {
  3264. REPORT_INNER_ERROR("E19999", "Node:%s can't find in node_inputs_outputs_map_, check invalid", node_name.c_str());
  3265. GELOGE(FAILED, "Can not find input output context, node:%s.", node_name.c_str());
  3266. return FAILED;
  3267. }
  3268. if ((node_inputs_outputs_map_[node_name].first.empty() && node_inputs_outputs_map_[node_name].second.empty() &&
  3269. node->op() != kDpop) ||
  3270. (node->op() == ge::parser::CONSTANT && node_inputs_outputs_map_[node_name].second.empty())) {
  3271. GELOGI("%s will inset to node_to_delete", node_name.c_str());
  3272. node_to_delete.insert(node_name);
  3273. }
  3274. }
  3275. // delete isolate nodes
  3276. auto nodeList = graph_def->mutable_node();
  3277. for (auto iter = nodeList->begin(); iter != nodeList->end();) {
  3278. if (node_to_delete.count(iter->name()) != 0) {
  3279. GELOGI("%s has zero input and output, will delete.", iter->name().c_str());
  3280. iter = nodeList->erase(iter);
  3281. } else {
  3282. iter++;
  3283. }
  3284. }
  3285. return SUCCESS;
  3286. }
Status TensorFlowModelParser::RecordFusionResult(const std::shared_ptr<ge::ScopeGraph> &scope_graph,
                                                 const domi::tensorflow::NodeDef *node, const ge::OpDescPtr &op_desc) {
  // Records scope-fusion provenance on op_desc for data dump: the list of
  // original node names and, for each mapped fused output, the original
  // producer name and output index.
  // The caller guarantees that the pointer is not null.
  GELOGI("RecordFusionResult for %s start.", op_desc->GetName().c_str());
  auto &impl_scope_graph = scope_graph->impl_;
  ge::FusionScopesResult *fusion_result = impl_scope_graph->GetFusionScopesResults(node);
  if (fusion_result == nullptr) {
    // Not a scope-fusion node: nothing to record, not an error.
    GELOGW("fusion_result is not found.");
    return SUCCESS;
  }
  std::vector<std::string> original_names;
  auto nodes = fusion_result->Nodes();
  std::transform(nodes.begin(), nodes.end(), std::back_inserter(original_names),
                 [](ge::OperatorPtr n) -> std::string { return ParserUtils::GetOperatorName(*n); });
  GELOGI("Op %s original_names size = %zu.", op_desc->GetName().c_str(), original_names.size());
  // Failures to set dump attributes are logged but do not abort parsing.
  bool ret = ge::AttrUtils::SetListStr(op_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
  if (!ret) {
    GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str());
  }
  auto outputs_desc = op_desc->GetAllOutputsDesc();
  auto &impl = fusion_result->impl_;
  for (auto &fusion_output : impl->GetOutputs()) {
    for (size_t i = 0; i < fusion_output.second.size(); ++i) {
      // kFusionDisableIndex marks original outputs not mapped onto the fused op.
      if (fusion_output.second[i] == ge::kFusionDisableIndex) {
        continue;
      }
      if (fusion_output.second[i] >= static_cast<int32_t>(op_desc->GetOutputsSize())) {
        REPORT_INNER_ERROR("E19999", "fusion output index:%d of node:%s(%s) must less than outputs desc size %zu.",
                           fusion_output.second[i], op_desc->GetName().c_str(), op_desc->GetType().c_str(),
                           op_desc->GetOutputsSize());
        GELOGE(PARAM_INVALID, "fusion output index %d must less than outputs desc size %zu.", fusion_output.second[i],
               op_desc->GetOutputsSize());
        return PARAM_INVALID;
      }
      ret = ge::AttrUtils::SetStr(op_desc->MutableOutputDesc(fusion_output.second[i]),
                                  ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, fusion_output.first);
      if (!ret) {
        GELOGW("Set %s to %s %d output fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME.c_str(), op_desc->GetName().c_str(),
               fusion_output.second[i]);
      }
      ret = ge::AttrUtils::SetInt(op_desc->MutableOutputDesc(fusion_output.second[i]),
                                  ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, i);
      if (!ret) {
        GELOGW("Set %s to %s %d output fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX.c_str(),
               op_desc->GetName().c_str(), fusion_output.second[i]);
      }
    }
  }
  return SUCCESS;
}
  3337. Status TensorFlowModelParser::SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context,
  3338. const std::vector<std::pair<std::string, int32_t>> &inputs,
  3339. const std::vector<std::pair<std::string, int32_t>> &outputs) {
  3340. int32_t in_index = 0;
  3341. for (const auto &in : inputs) {
  3342. bool is_ctrl = in.second == kControlSlot;
  3343. op_node_context.input_map[in.first].emplace_back(std::make_pair(in.second, is_ctrl ? kControlSlot : in_index));
  3344. SaveEdgesControlInfo(node_def->name(), is_ctrl);
  3345. in_index = is_ctrl ? in_index : in_index + 1;
  3346. }
  3347. int32_t out_index = 0;
  3348. for (const auto &out : outputs) {
  3349. bool is_ctrl = out.second == kControlSlot;
  3350. op_node_context.output_map[out.first].emplace_back(std::make_pair(is_ctrl ? kControlSlot : out_index, out.second));
  3351. out_index = is_ctrl ? out_index : out_index + 1;
  3352. }
  3353. return SUCCESS;
  3354. }
  3355. void TensorFlowModelParser::GetFusionInputInfo(
  3356. const string &fusion_op_name, OpNodeContext &fusion_context,
  3357. std::map<string, std::pair<std::string, std::pair<int32_t, int32_t>>> &remap_data_input,
  3358. std::map<string, std::vector<string>> &remap_ctrl_input, std::set<string> &fusion_input_nodes) {
  3359. for (const auto &fusion_input : fusion_context.input_map) {
  3360. string fusion_src_name = fusion_input.first;
  3361. for (const auto &fusion_idx_pair : fusion_input.second) {
  3362. string key = fusion_op_name + std::to_string(fusion_idx_pair.second);
  3363. if (fusion_idx_pair.second != kControlSlot) {
  3364. remap_data_input[key] = {fusion_src_name, {fusion_idx_pair.first, fusion_idx_pair.second}};
  3365. } else {
  3366. remap_ctrl_input[key].emplace_back(fusion_src_name);
  3367. }
  3368. }
  3369. fusion_input_nodes.insert(fusion_src_name);
  3370. }
  3371. }
  3372. void TensorFlowModelParser::GetFusionOutputInfo(
  3373. const string &fusion_op_name, OpNodeContext &fusion_context,
  3374. std::map<string, std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>>> &remap_data_output,
  3375. std::map<string, std::vector<string>> &remap_ctrl_output, std::set<string> &fusion_output_nodes) {
  3376. for (const auto &fusion_output : fusion_context.output_map) {
  3377. string fusion_dst_name = fusion_output.first;
  3378. for (const auto &fusion_idx_pair : fusion_output.second) {
  3379. string key = fusion_op_name + std::to_string(fusion_idx_pair.first);
  3380. if (fusion_idx_pair.first != kControlSlot) {
  3381. remap_data_output[key].emplace_back(
  3382. std::make_pair(fusion_dst_name, std::make_pair(fusion_idx_pair.first, fusion_idx_pair.second)));
  3383. } else {
  3384. remap_ctrl_output[key].emplace_back(fusion_dst_name);
  3385. }
  3386. }
  3387. fusion_output_nodes.insert(fusion_dst_name);
  3388. }
  3389. }
  3390. void TensorFlowModelParser::UpdateInnerInputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
  3391. const std::vector<std::string> &inner_nodes_name,
  3392. std::set<string> &fusion_input_nodes) {
  3393. std::map<string, std::pair<std::string, std::pair<int32_t, int32_t>>> remap_data_input;
  3394. std::map<string, std::vector<string>> remap_ctrl_input;
  3395. GetFusionInputInfo(fusion_op_name, fusion_context, remap_data_input, remap_ctrl_input, fusion_input_nodes);
  3396. for (const auto &node_name : inner_nodes_name) {
  3397. auto context_iter = op_node_context_map_.find(node_name);
  3398. if (context_iter != op_node_context_map_.end()) {
  3399. OpNodeContext &op_node_context = context_iter->second;
  3400. // update input map of inner node
  3401. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> tmp_input_map;
  3402. for (auto iter = op_node_context.input_map.begin(); iter != op_node_context.input_map.end();) {
  3403. string src_name = iter->first;
  3404. if (src_name == ge::kInputFromFusionScope) {
  3405. std::vector<std::pair<int32_t, int32_t>> &input_idx = iter->second;
  3406. for (const auto &in_pair : input_idx) {
  3407. if (in_pair.second != kControlSlot) {
  3408. auto data = remap_data_input[fusion_op_name + std::to_string(in_pair.first)];
  3409. tmp_input_map[data.first].emplace_back(std::make_pair(data.second.first, in_pair.second));
  3410. GELOGI("Update inner input, src:%s, idx:%u->%u", data.first.c_str(), data.second.first, in_pair.second);
  3411. }
  3412. }
  3413. auto ctrl = remap_ctrl_input[fusion_op_name + std::to_string(kControlSlot)];
  3414. for (const auto &ctrl_in : ctrl) {
  3415. tmp_input_map[ctrl_in].emplace_back(std::make_pair(kControlSlot, kControlSlot));
  3416. SaveEdgesControlInfo(node_name, kControlSlot);
  3417. }
  3418. iter = op_node_context.input_map.erase(iter);
  3419. } else {
  3420. ++iter;
  3421. }
  3422. }
  3423. op_node_context.input_map.insert(tmp_input_map.cbegin(), tmp_input_map.cend());
  3424. // update output map of pre node
  3425. for (const auto &in_iter : op_node_context.input_map) {
  3426. auto src_iter = op_node_context_map_.find(in_iter.first);
  3427. if (src_iter != op_node_context_map_.end()) {
  3428. std::vector<std::pair<int32_t, int32_t>> input_pairs = in_iter.second;
  3429. OpNodeContext &src_context = src_iter->second;
  3430. src_context.output_map[node_name].assign(input_pairs.begin(), input_pairs.end());
  3431. }
  3432. }
  3433. }
  3434. }
  3435. }
  3436. void TensorFlowModelParser::UpdateInnerOutputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
  3437. const std::vector<std::string> &inner_nodes_name,
  3438. std::set<string> &fusion_output_nodes) {
  3439. std::map<string, std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>>> remap_data_output;
  3440. std::map<string, std::vector<string>> remap_ctrl_output;
  3441. GetFusionOutputInfo(fusion_op_name, fusion_context, remap_data_output, remap_ctrl_output, fusion_output_nodes);
  3442. for (const auto &node_name : inner_nodes_name) {
  3443. auto context_iter = op_node_context_map_.find(node_name);
  3444. if (context_iter != op_node_context_map_.end()) {
  3445. OpNodeContext &op_node_context = context_iter->second;
  3446. // update output map of inner node
  3447. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> tmp_output_map;
  3448. for (auto iter = op_node_context.output_map.begin(); iter != op_node_context.output_map.end();) {
  3449. string dst_name = iter->first;
  3450. if (dst_name == ge::kOutputToFusionScope) {
  3451. std::vector<std::pair<int32_t, int32_t>> &output_idx = iter->second;
  3452. for (const auto &out_pair : output_idx) {
  3453. if (out_pair.second != kControlSlot) {
  3454. auto data_outputs = remap_data_output[fusion_op_name + std::to_string(out_pair.second)];
  3455. for (const auto &data : data_outputs) {
  3456. tmp_output_map[data.first].emplace_back(std::make_pair(out_pair.first, data.second.second));
  3457. GELOGI("Update inner output, dst:%s, idx:%u->%u.", data.first.c_str(), out_pair.first,
  3458. data.second.second);
  3459. }
  3460. }
  3461. }
  3462. auto ctrl = remap_ctrl_output[fusion_op_name + std::to_string(kControlSlot)];
  3463. for (const auto &ctrl_in : ctrl) {
  3464. tmp_output_map[ctrl_in].emplace_back(std::make_pair(kControlSlot, kControlSlot));
  3465. }
  3466. iter = op_node_context.output_map.erase(iter);
  3467. } else {
  3468. ++iter;
  3469. }
  3470. }
  3471. op_node_context.output_map.insert(tmp_output_map.cbegin(), tmp_output_map.cend());
  3472. // update input map of pre node
  3473. for (const auto &out_iter : op_node_context.output_map) {
  3474. auto dst_iter = op_node_context_map_.find(out_iter.first);
  3475. if (dst_iter != op_node_context_map_.end()) {
  3476. std::vector<std::pair<int32_t, int32_t>> output_pairs = out_iter.second;
  3477. OpNodeContext &dst_context = dst_iter->second;
  3478. dst_context.input_map[node_name].assign(output_pairs.begin(), output_pairs.end());
  3479. }
  3480. }
  3481. }
  3482. }
  3483. }
  3484. Status TensorFlowModelParser::UpdateInnerNodeContext(const string &fusion_op_name,
  3485. const std::vector<std::string> &inner_nodes_name) {
  3486. auto fusion_iter = op_node_context_map_.find(fusion_op_name);
  3487. if (fusion_iter == op_node_context_map_.end()) {
  3488. REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", fusion_op_name.c_str());
  3489. GELOGE(INTERNAL_ERROR, "Can't find context for fusion node %s.", fusion_op_name.c_str());
  3490. return INTERNAL_ERROR;
  3491. }
  3492. OpNodeContext &fusion_context = fusion_iter->second;
  3493. std::set<string> fusion_input_nodes;
  3494. std::set<string> fusion_output_nodes;
  3495. UpdateInnerInputMap(fusion_op_name, fusion_context, inner_nodes_name, fusion_input_nodes);
  3496. UpdateInnerOutputMap(fusion_op_name, fusion_context, inner_nodes_name, fusion_output_nodes);
  3497. for (const auto &in_name : fusion_input_nodes) {
  3498. auto fusion_in = op_node_context_map_.find(in_name);
  3499. if (fusion_in != op_node_context_map_.end()) {
  3500. OpNodeContext &fusion_in_context = fusion_in->second;
  3501. fusion_in_context.output_map.erase(fusion_op_name);
  3502. }
  3503. }
  3504. for (const auto &out_name : fusion_output_nodes) {
  3505. auto fusion_out = op_node_context_map_.find(out_name);
  3506. if (fusion_out != op_node_context_map_.end()) {
  3507. OpNodeContext &fusion_out_context = fusion_out->second;
  3508. fusion_out_context.input_map.erase(fusion_op_name);
  3509. }
  3510. }
  3511. op_node_context_map_.erase(fusion_op_name);
  3512. return SUCCESS;
  3513. }
  3514. Status TensorFlowModelParser::AddFusionInnerNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph,
  3515. const string &fusion_op_name, vector<string> &node_name_list) {
  3516. auto &impl_scope_graph = scope_graph->impl_;
  3517. GE_CHECK_NOTNULL(impl_scope_graph);
  3518. ge::FusionScopesResult *fusion_result = impl_scope_graph->GetFusionScopesResults(fusion_op_name);
  3519. GE_CHECK_NOTNULL(fusion_result);
  3520. auto &impl_fusion_rlt = fusion_result->impl_;
  3521. GE_CHECK_NOTNULL(impl_fusion_rlt);
  3522. ge::FusionInnerNodesInfo inner_nodes_info = impl_fusion_rlt->GetInnerNodesInfo();
  3523. vector<string> inner_nodes_name;
  3524. for (const auto &info : inner_nodes_info) {
  3525. string node_name;
  3526. string type;
  3527. std::vector<std::pair<std::string, int32_t>> inputs;
  3528. std::vector<std::pair<std::string, int32_t>> outputs;
  3529. const ge::Operator *op = nullptr;
  3530. std::tie(node_name, type, inputs, outputs, op) = info;
  3531. NodeDef *node_def = new (std::nothrow) NodeDef();
  3532. GE_CHECK_NOTNULL(node_def);
  3533. node_def->set_name(node_name);
  3534. node_def->set_op(type);
  3535. nodedef_map_[node_name] = node_def;
  3536. fusion_nodedef_list.push_back(node_def);
  3537. for (const auto &in : inputs) {
  3538. // The input value is not used in the subsequent process. The value is added only for placeholders.
  3539. node_def->add_input(in.first);
  3540. }
  3541. domi::tensorflow::AttrValue attr_value;
  3542. attr_value.set_b(true);
  3543. ge::TensorFlowUtil::AddNodeAttr(kAttrNameIsScopeInnerNode, attr_value, node_def);
  3544. OpNodeContext &op_node_context = op_node_context_map_[node_name];
  3545. Status ret = SetOriginNodeContext(node_def, op_node_context, inputs, outputs);
  3546. if (ret != SUCCESS) {
  3547. GELOGE(ret, "Failed to add context and attrs, node:%s.", node_name.c_str());
  3548. return ret;
  3549. }
  3550. scope_inner_node_map_.insert({node_name, op});
  3551. node_name_list.emplace_back(node_name);
  3552. inner_nodes_name.emplace_back(node_name);
  3553. GELOGI("Add fusion inner node def, name:%s, type:%s.", node_name.c_str(), type.c_str());
  3554. }
  3555. Status ret = UpdateInnerNodeContext(fusion_op_name, inner_nodes_name);
  3556. if (ret != SUCCESS) {
  3557. GELOGE(ret, "Failed to update inner node context, fusion_op_name:%s.", fusion_op_name.c_str());
  3558. return ret;
  3559. }
  3560. return SUCCESS;
  3561. }
  3562. Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph,
  3563. vector<string> &node_name_list) {
  3564. vector<string> node_name_list_new;
  3565. size_t op_node_list_size = node_name_list.size();
  3566. DumpAllNodeContext("BeforeAddFusionNodeDef");
  3567. for (size_t i = 0; i < op_node_list_size; ++i) {
  3568. const string op_node_name = node_name_list[i];
  3569. std::map<string, vector<const NodeDef *>>::const_iterator iter = fusion_op_nodedef_map_.find(op_node_name);
  3570. if (iter != fusion_op_nodedef_map_.end()) {
  3571. vector<string> fusion_op_info = fusion_op_type_map_[op_node_name];
  3572. if (fusion_op_info[0] != ge::kScopeToMultiNodes) {
  3573. NodeDef *node_def = new (std::nothrow) NodeDef();
  3574. GE_CHECK_NOTNULL(node_def);
  3575. node_def->set_name(op_node_name);
  3576. node_def->set_op(fusion_op_info[0]);
  3577. nodedef_map_[op_node_name] = node_def;
  3578. fusion_nodedef_list.push_back(node_def);
  3579. OpNodeContext &node_context = op_node_context_map_[node_def->name()];
  3580. for (const auto &input : node_context.input_map) {
  3581. // The input value is not used in the subsequent process. The value is added only for placeholders.
  3582. node_def->add_input(input.first);
  3583. }
  3584. node_name_list_new.emplace_back(op_node_name);
  3585. GELOGI("Add Fusion node def, name:%s, type:%s.", node_def->name().c_str(), node_def->op().c_str());
  3586. } else {
  3587. Status ret = AddFusionInnerNodeDef(scope_graph, op_node_name, node_name_list_new);
  3588. if (ret != SUCCESS) {
  3589. REPORT_INNER_ERROR("E19999",
  3590. "Failed to add fusion inner nodes for fusion op:%s, "
  3591. "please check FusionScopesResult set in scope fusion pass",
  3592. op_node_name.c_str());
  3593. GELOGE(ret, "Failed to add fusion inner node, fusion_op_name:%s.", op_node_name.c_str());
  3594. return ret;
  3595. }
  3596. GELOGI("Add fusion inner nodes successfully, fusion name:%s.", op_node_name.c_str());
  3597. op_node_context_map_.erase(op_node_name);
  3598. }
  3599. } else {
  3600. node_name_list_new.emplace_back(op_node_name);
  3601. }
  3602. }
  3603. node_name_list.clear();
  3604. node_name_list.assign(node_name_list_new.begin(), node_name_list_new.end());
  3605. DumpAllNodeContext("AfterAddFusionNodeDef");
  3606. return SUCCESS;
  3607. }
  3608. Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph,
  3609. std::mutex *const graph_mutex,
  3610. const domi::tensorflow::NodeDef *node_def) {
  3611. // This is an internal function. The pointer input parameter is not empty when this function is invoked.
  3612. string node_name = node_def->name();
  3613. string node_op = node_def->op();
  3614. auto iter = parser->scope_inner_node_map_.find(node_name);
  3615. if (iter == parser->scope_inner_node_map_.end()) {
  3616. REPORT_INNER_ERROR("E19999", "Node:%s can't find in scope_inner_node_map_, check invalid", node_name.c_str());
  3617. GELOGE(PARAM_INVALID, "Failed to find scope inner node:%s, type:%s.", node_name.c_str(), node_op.c_str());
  3618. return PARAM_INVALID;
  3619. }
  3620. const ge::Operator *op = iter->second;
  3621. ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op);
  3622. GE_CHECK_NOTNULL(op_desc);
  3623. ge::NodePtr node;
  3624. {
  3625. std::lock_guard<std::mutex> lock(*graph_mutex);
  3626. node = graph->AddNode(op_desc);
  3627. }
  3628. if (node == nullptr) {
  3629. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
  3630. op_desc->GetType().c_str(), graph->GetName().c_str());
  3631. GELOGE(INTERNAL_ERROR, "Failed to Add scope inner node:%s, type:%s.", op_desc->GetName().c_str(),
  3632. op_desc->GetType().c_str());
  3633. return INTERNAL_ERROR;
  3634. }
  3635. {
  3636. std::lock_guard<std::mutex> lock(parser->nodeMapMutex_);
  3637. parser->node_map_[node_name] = node;
  3638. }
  3639. GELOGI("Add scope inner node successfully, node name:%s, type:%s.", op_desc->GetName().c_str(),
  3640. op_desc->GetType().c_str());
  3641. return SUCCESS;
  3642. }
  3643. void TensorFlowModelParser::DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase) {
  3644. GELOGD("phase:%s === Begin to dump context for node:%s ===", phase.c_str(), node_name.c_str());
  3645. for (const auto &input : ctx.input_map) {
  3646. for (const auto &input_idx : input.second) {
  3647. GELOGD(" Input info: %s:%d --> in_idx %d.", input.first.c_str(), input_idx.first, input_idx.second);
  3648. }
  3649. }
  3650. for (const auto &output : ctx.output_map) {
  3651. for (const auto &output_idx : output.second) {
  3652. GELOGD(" Output info: out_idx %d --> %s:%d.", output_idx.first, output.first.c_str(), output_idx.second);
  3653. }
  3654. }
  3655. GELOGD("phase:%s === End to dump context for node:%s ===", phase.c_str(), node_name.c_str());
  3656. }
  3657. void TensorFlowModelParser::DumpAllNodeContext(const string &phase) const {
  3658. if (!IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) {
  3659. return;
  3660. }
  3661. for (const auto &iter : op_node_context_map_) {
  3662. DumpNodeContext(iter.first, iter.second, phase);
  3663. }
  3664. }
  3665. Status TensorFlowModelParser::CheckAndUpdateInputDesc(const ge::ComputeGraphPtr &compute_graph) {
  3666. GE_CHECK_NOTNULL(compute_graph);
  3667. for (auto &node : compute_graph->GetDirectNode()) {
  3668. auto op_desc = node->GetOpDesc();
  3669. GE_CHECK_NOTNULL(op_desc);
  3670. for (auto &in_anchor : node->GetAllInDataAnchors()) {
  3671. if (!(op_desc->IsOptionalInput(static_cast<uint32_t>(in_anchor->GetIdx())))) {
  3672. continue;
  3673. }
  3674. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  3675. auto in_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()));
  3676. if ((peer_out_anchor != nullptr) && (in_desc == nullptr)) {
  3677. // The input is connected to the peer output but TensorDesc is invalid, update TensorDesc to valid.
  3678. ge::GeTensorDesc tensor_desc;
  3679. auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc);
  3680. if (ret != ge::GRAPH_SUCCESS) {
  3681. REPORT_CALL_ERROR("E19999", "Update index:%d of input desc in op:%s(%s) failed", in_anchor->GetIdx(),
  3682. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  3683. GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
  3684. return ret;
  3685. }
  3686. GELOGI("Update input desc to valid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
  3687. } else if ((peer_out_anchor == nullptr) && (in_desc != nullptr)) {
  3688. // The input is not connected to the peer output but TensorDesc is valid, update TensorDesc to invalid.
  3689. ge::GeTensorDesc tensor_desc(ge::GeShape(), FORMAT_RESERVED, DT_UNDEFINED);
  3690. auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc);
  3691. if (ret != ge::GRAPH_SUCCESS) {
  3692. REPORT_CALL_ERROR("E19999", "Update index:%d of input desc in op:%s(%s) failed", in_anchor->GetIdx(),
  3693. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  3694. GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
  3695. return ret;
  3696. }
  3697. GELOGI("Update input desc to invalid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
  3698. }
  3699. }
  3700. }
  3701. return SUCCESS;
  3702. }
  3703. Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes) {
  3704. auto &user_specified_nodes = ge::GetParserContext().user_out_nodes;
  3705. if (!user_specified_nodes.empty()) {
  3706. for (auto &output_node_info : user_specified_nodes) {
  3707. ParserUtils::UpdateOutputNodeInfo(final_output_nodes, output_node_info);
  3708. }
  3709. }
  3710. return SUCCESS;
  3711. }
  3712. Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph) {
  3713. GE_CHECK_NOTNULL(root_graph);
  3714. for (const NodePtr &node : root_graph->GetAllNodes()) {
  3715. if ((node == nullptr) || (node->GetOpDesc() == nullptr)) {
  3716. continue;
  3717. }
  3718. std::string model_data;
  3719. if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) {
  3720. ge::Model model;
  3721. auto load_ret = ge::Model::Load(ge::PtrToPtr<char_t, const uint8_t>(model_data.data()), model_data.size(), model);
  3722. if (load_ret != GRAPH_SUCCESS) {
  3723. GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str());
  3724. REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str());
  3725. return INTERNAL_ERROR;
  3726. }
  3727. Graph graph = model.GetGraph();
  3728. GELOGD("Get subgraph[%s] from model[%s].", ParserUtils::GetGraphName(graph).c_str(), node->GetName().c_str());
  3729. Status ret = MappingAndAddSubGraph(node, graph, root_graph);
  3730. if (ret != SUCCESS) {
  3731. GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]Node:%s.", node->GetName().c_str());
  3732. REPORT_CALL_ERROR("E19999", "Failed to map and add sub graph, node:%s.", node->GetName().c_str());
  3733. return INTERNAL_ERROR;
  3734. }
  3735. (void)node->GetOpDesc()->DelAttr(kExternalModel);
  3736. }
  3737. }
  3738. return SUCCESS;
  3739. }
}  // namespace ge
namespace domi {
// Register the TensorFlow model and weights parsers with the framework factory
// so they can be instantiated by framework type (TENSORFLOW) at runtime.
REGISTER_MODEL_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowModelParser);
REGISTER_WEIGHTS_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowWeightsParser);
}  // namespace domi