
opr_replace.cpp 82 kB

  1. #include <cstring>
  2. #include "megbrain/dtype.h"
  3. #include "megbrain/opr/basic_arith.h"
  4. #include "megbrain/opr/blas.h"
  5. #include "megbrain/opr/dnn/convolution.h"
  6. #include "megbrain/opr/dnn/pooling.h"
  7. #include "megbrain/opr/nn_int.h"
  8. #include "megbrain/opr/tensor_manip.h"
  9. #include "megbrain/utils/arith_helper.h"
  10. #if MGB_ENABLE_TENSOR_RT
  11. #include "megbrain/gopt/basic_arith.h"
  12. #include "megbrain/gopt/inference.h"
  13. #include "megbrain/gopt/misc.h"
  14. #include "megbrain/tensorrt/opr_replace.h"
  15. #include "megbrain/tensorrt/tensorrt_engine_cache.h"
  16. #include "megbrain/tensorrt/tensorrt_opr.h"
  17. #pragma GCC diagnostic push
  18. #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
  19. using namespace mgb;
  20. using namespace gopt;
  21. using namespace cg;
  22. template <typename T>
  23. using TensorRTUniquePtr = opr::intl::TensorRTUniquePtr<T>;
  24. namespace {
  25. nvinfer1::DataType mgb_dtype_to_trt_dtype(DType dtype) {
  26. switch (dtype.enumv()) {
  27. case DTypeEnum::Float32:
  28. return nvinfer1::DataType::kFLOAT;
  29. case DTypeEnum::Float16:
  30. return nvinfer1::DataType::kHALF;
  31. case DTypeEnum::QuantizedS8:
  32. return nvinfer1::DataType::kINT8;
  33. case DTypeEnum::Int32:
  34. return nvinfer1::DataType::kINT32;
  35. default:
  36. mgb_throw(
  37. InternalError,
  38. "invalid data type which is not supported in TensorRT: %s",
  39. dtype.name());
  40. }
  41. }
  42. } // namespace
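// Implementation of the TensorRT replace pass: it scans the computing graph,
// groups replaceable operators into TensorRT subgraphs, builds an
// nvinfer1::INetworkDefinition for each subgraph, and rewrites the graph so
// that each subgraph is executed through TensorRT oprs.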
  43. class TensorRTReplacePass::Impl final {
  44. static constexpr size_t OPR_FAIL_LOG_NUM = 10;
  45. static constexpr float i8_max = std::numeric_limits<int8_t>::max();
  46. using TensorRTGraphFeatureBits = opr::intl::TensorRTGraphFeatureBits;
  47. using ConvFormat = opr::Convolution::Param::Format;
  48. using ExtraDep = ThinHashMap<OperatorNodeBase*, VarNodeArray>;
  49. const Pass& m_pass;
  50. OptState& m_opt_state;
  51. SubGraph::Rewriter m_rewriter;
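// Bookkeeping for a single TensorRT subgraph: the builder and network being
// constructed, the subgraph's input/output varnodes, and the mapping from
// megbrain varnodes to TensorRT ITensors.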
  52. struct TensorRTGraph {
  53. using Callback = cg::DepOprIter::Callback;
  54. nvinfer1::IBuilder* builder;
  55. nvinfer1::INetworkDefinition* network;
  56. ThinHashSet<VarNode*> inputs;
  57. ThinHashSet<VarNode*> outputs;
// Maps an output varnode of the original computing graph to its output
// index on the corresponding TensorRTOpr.
  60. ThinHashMap<VarNode*, size_t> output2idx;
// Input and output varnodes marked as NCHW4 format; dimshuffle and typecvt
// must be inserted so that the TensorRTOpr's inputs and outputs match
// those of the non-fused operators.
  64. ThinHashSet<VarNode*> mark_input_varnode_nchw4;
  65. ThinHashSet<VarNode*> mark_output_varnode_nchw4;
  66. VarNodeArray trt_inputs;
  67. VarNodeArray trt_outputs;
// Every TensorRT graph owns its own map from varnode to ITensor,
// because a varnode can belong to two different TensorRT subgraphs.
  70. ThinHashMap<VarNode*, nvinfer1::ITensor*> varnode2itensor;
  71. TensorRTGraphFeatureBits feature_bits;
  72. TensorRTGraph(
  73. TensorRTGraphFeatureBits feature_bits =
  74. TensorRTGraphFeatureBits::NCHW_FLOAT)
  75. : builder{nvinfer1::createInferBuilder(
  76. opr::TensorRTOpr::Logger::instance())},
  77. network{nullptr},
  78. feature_bits{feature_bits} {}
  79. void mark_varnode_format_nchw4();
  80. };
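// Records an operator that could not be replaced, together with the reason,
// for the summary log printed at the end of the pass.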
  81. struct FailInfo {
  82. OperatorNodeBase* opr;
  83. std::string fail_msg;
  84. };
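// Keeps host tensors alive in the graph's user data, since nvinfer1::Weights
// only hold raw pointers into their storage (see the PowC handler below).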
  85. class HostTensorKeeper : public UserDataContainer::UserData {
  86. MGB_TYPEINFO_OBJ_DECL;
  87. public:
  88. std::vector<HostTensorND> htr;
  89. };
  90. std::unique_ptr<ConstVarPropogate> m_const_var_propogate;
  91. std::vector<std::shared_ptr<TensorRTGraph>> m_tensorrt_graphs;
  92. // use ThinHashMap instead of std::unordered_map
  93. ThinHashMap<OperatorNodeBase*, size_t> m_graph_map;
  94. ThinHashMap<OperatorNodeBase*, nvinfer1::IConvolutionLayer*> m_opr2convlayer;
  95. ThinHashMap<OperatorNodeBase*, nvinfer1::IDeconvolutionLayer*> m_opr2deconvlayer;
  96. size_t m_opr_num;
  97. size_t m_opr_fail_num;
  98. std::vector<FailInfo> m_opr_fail;
  99. struct OprTrait {
// Check whether the opr can be replaced; a missing entry means unsupported.
  101. thin_function<Maybe<std::string>(OperatorNodeBase*)> get_replace_fail_msg;
// Add the opr to the TensorRT network; likewise, a missing entry means unsupported.
  103. thin_function<void(nvinfer1::INetworkDefinition*, OperatorNodeBase*)>
  104. add_to_nvinfer;
  105. };
  106. ThinHashMap<Typeinfo*, OprTrait> m_opr_trait;
  107. // Find parent conv of elemwise ADD opr.
  108. VarNodeArray find_parent_conv(OperatorNodeBase* opr);
// Make a TRT ITensor for varnode var and add it as an input of the TRT network.
// Returns false if a tensor for var has already been made and added,
// true if var is encountered for the first time.
  112. bool check_input(
  113. VarNode* var, OperatorNodeBase* opr,
  114. mgb::SmallVector<TENSORRT_NO_DIMENSIONTYPE(nvinfer1::DimensionType)>
  115. dimtypes = {});
  116. HostTensorND get_value(VarNode* var, ConvFormat format = ConvFormat::NCHW);
  117. void set_itensor_dynamic_range(VarNode* var, OperatorNodeBase* opr);
  118. float get_scale(DType data_type);
// Check whether an operator is a quantized int8 operator; such operators
// can be fused into a quantized TensorRT subgraph.
  122. bool is_quantized_int8_operator(OperatorNodeBase* opr);
  123. Maybe<std::string> has_fail_msg(OperatorNodeBase* opr);
  124. static nvinfer1::ITensor& replace(
nvinfer1::INetworkDefinition* network, nvinfer1::ITensor& pre_output,
  126. OperatorNodeBase* opr);
  127. void update_graph();
  128. void mark_varnode_format_nchw4();
  129. void detect_replace();
  130. public:
  131. Impl(const Pass& pass, OptState& opt_state)
  132. : m_pass{pass},
  133. m_opt_state{opt_state},
  134. m_rewriter{opt_state.graph().make_rewriter()},
  135. m_const_var_propogate{std::make_unique<ConstVarPropogate>(
  136. ConstVarType::IMMUTABLE_AND_PARAM)} {
  137. #define REPLACE_FAIL_MSG_EPILOGUE \
  138. { \
  139. auto&& mgr = opr->owner_graph()->static_infer_manager(); \
  140. auto&& shp = mgr.infer_shape_fallible(opr->output(0)); \
  141. if (!shp) \
  142. return "Unsupported opr, because operator shape cannot be " \
  143. "inferred at compile time."; \
  144. else \
  145. return None; \
  146. }
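// For each supported operator type, register a predicate that returns a
// failure message when the operator cannot be replaced by TensorRT, or
// None when replacement is possible.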
  147. m_opr_trait[opr::Elemwise::typeinfo()].get_replace_fail_msg =
  148. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  149. bool has_scalar = false;
  150. for (auto&& inp : opr->input()) {
  151. if (inp->shape().is_scalar()) {
  152. has_scalar = true;
  153. break;
  154. }
  155. }
  156. if (has_scalar)
  157. return "Elemwise with scalar input is not supported.";
  158. if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
  159. opr->input(0)->dtype() != dtype::Float32()) {
  160. return "Unsupported data type.";
  161. }
  162. using Mode = opr::Elemwise::Mode;
  163. static const ThinHashSet<Mode> supported_modes {
  164. #if NV_TENSOR_RT_VERSION >= 5105
  165. Mode::SIN, Mode::COS, Mode::ASIN, Mode::ACOS, Mode::CEIL, Mode::FLOOR,
  166. #endif
  167. Mode::EXP, Mode::LOG, Mode::ABS,
  168. Mode::RELU, Mode::SIGMOID, Mode::TANH, Mode::ADD, Mode::MUL,
  169. Mode::MIN, Mode::MAX, Mode::SUB, Mode::TRUE_DIV, Mode::POW,
  170. Mode::FUSE_ADD_RELU, Mode::FUSE_ADD_TANH, Mode::FUSE_ADD_SIGMOID
  171. };
  172. auto mode = opr->cast_final_safe<opr::Elemwise>().param().mode;
  173. if (!supported_modes.count(mode)) {
  174. return "Unsupported Elemwise mode.";
  175. }
  176. #if NV_TENSOR_RT_VERSION >= 6001
  177. if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
  178. TensorShapeArray inps;
  179. for (auto&& inp : opr->input()) {
  180. inps.push_back(inp->shape());
  181. }
  182. TensorShape brdcast;
  183. megdnn::Elemwise::deduce_shape(inps, brdcast);
  184. if (brdcast.ndim < 4) {
  185. return "Elemwise with QuantizedS8 data type must have more "
  186. "than 4 dimensions. Less than 3 dimensions is not "
  187. "supported since trt6.0.";
  188. }
  189. }
  190. #endif
  191. REPLACE_FAIL_MSG_EPILOGUE;
  192. };
  193. m_opr_trait[opr::ElemwiseMultiType::typeinfo()].get_replace_fail_msg =
  194. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  195. bool has_scalar = false;
  196. for (auto&& inp : opr->input()) {
  197. if (inp->shape().is_scalar()) {
  198. has_scalar = true;
  199. break;
  200. }
  201. }
  202. if (has_scalar)
  203. return "ElemwiseMultiType with scalar input is not supported.";
  204. for (auto&& inp : opr->input()) {
  205. if (inp->dtype().enumv() != DTypeEnum::QuantizedS8)
  206. return "Unsupported data type.";
  207. }
  208. if (opr->output(0)->dtype().enumv() != DTypeEnum::QuantizedS8)
  209. return "Unsupported data type.";
  210. using Mode = opr::ElemwiseMultiType::Mode;
  211. auto mode = opr->cast_final_safe<opr::ElemwiseMultiType>().param().mode;
  212. if (mode != Mode::QFUSE_ADD_RELU && mode != Mode::QADD) {
  213. return "Unsupported ElemwiseMultiType mode.";
  214. }
  215. REPLACE_FAIL_MSG_EPILOGUE;
  216. };
  217. m_opr_trait[opr::Convolution::typeinfo()].get_replace_fail_msg =
  218. [this](OperatorNodeBase* opr) -> Maybe<std::string> {
  219. if (opr->input(0)->dtype() != dtype::Float32())
  220. return "Non-Float32 convolution is not supported.";
  221. if (!m_const_var_propogate->is_const(opr->input(1)))
  222. return "Weights not constant. Not replaceable in TRT.";
  223. auto&& param = opr->cast_final_safe<opr::Convolution>().param();
  224. if (param.format != ConvFormat::NCHW)
  225. return "TensorRT replace pass only support NCHW format "
  226. "convolution.";
  227. if (param.mode == opr::Convolution::Param::Mode::CONVOLUTION)
  228. return "TensorRT does not support non cross correlation "
  229. "convolution.";
  230. REPLACE_FAIL_MSG_EPILOGUE;
  231. };
  232. m_opr_trait[opr::ConvBias::typeinfo()].get_replace_fail_msg =
  233. [this](OperatorNodeBase* opr) -> Maybe<std::string> {
  234. if (opr->input(0)->dtype() != dtype::Float32() &&
  235. opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8)
  236. return "Convolution is only supported for float32 or qint8.";
  237. if (!m_const_var_propogate->is_const(opr->input(1)))
  238. return "Weights not constant. Not replaceable in TRT.";
  239. if (opr->input().size() >= 3) {
  240. if (!m_const_var_propogate->is_const(opr->input(2)))
  241. return "Bias not constant. Not replaceable in TRT.";
  242. }
  243. auto&& param = opr->cast_final_safe<opr::ConvBias>().param();
  244. if (param.format != ConvFormat::NCHW && param.format != ConvFormat::NCHW4)
  245. return "TensorRT replace pass only support NCHW format "
  246. "convolution.";
  247. if (param.mode == opr::ConvBias::Param::Mode::CONVOLUTION)
  248. return "TensorRT does not support non cross correlation "
  249. "convolution.";
  250. REPLACE_FAIL_MSG_EPILOGUE;
  251. };
  252. m_opr_trait[opr::ConvolutionBackwardData::typeinfo()].get_replace_fail_msg =
  253. [this](OperatorNodeBase* opr) -> Maybe<std::string> {
  254. if (opr->input(0)->dtype() != dtype::Float32())
  255. return "Non-Float32 Deconvolution is not supported.";
  256. if (!m_const_var_propogate->is_const(opr->input(0)))
  257. return "Weights not constant. Not replaceable in TRT.";
  258. auto&& param = opr->cast_final_safe<opr::ConvolutionBackwardData>().param();
  259. if (param.dilate_h != 1 || param.dilate_w != 1)
  260. return "TensorRT does not support dilation deconvolution.";
  261. if (param.format != ConvFormat::NCHW)
  262. return "TensorRT replace pass only support NCHW format deconv.";
  263. if (param.mode == opr::ConvBias::Param::Mode::CONVOLUTION)
  264. return "TensorRT does not support non cross correlation "
  265. "deconvolution.";
  266. REPLACE_FAIL_MSG_EPILOGUE;
  267. };
  268. m_opr_trait[opr::Pooling::typeinfo()].get_replace_fail_msg =
  269. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  270. auto pool = opr->try_cast_final<opr::Pooling>();
  271. auto&& param = pool->param();
  272. if (param.format != opr::Pooling::Param::Format::NCHW &&
  273. param.format != opr::Pooling::Param::Format::NCHW4)
  274. return "Pooling is only supported for NCHW and NCHW4";
  275. REPLACE_FAIL_MSG_EPILOGUE;
  276. };
  277. m_opr_trait[opr::Concat::typeinfo()].get_replace_fail_msg =
  278. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  279. if (opr->input(0)->dtype() != dtype::Float32() &&
  280. opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8) {
  281. return "Concat only support float32 and quantized int8.";
  282. }
  283. // TODO: TensorRT only supports concat on channel dimension,
  284. // we can set nvinfer1::DimensionType to kCHANNEL to support
  285. // concat on other dimension
  286. if (!(opr->input(0)->shape().ndim == 4 &&
  287. opr->cast_final_safe<opr::Concat>().param().axis == 1)) {
  288. return "Concat only support input is NCHW and axis is 1.";
  289. }
  290. REPLACE_FAIL_MSG_EPILOGUE;
  291. };
  292. m_opr_trait[opr::MatrixMul::typeinfo()].get_replace_fail_msg =
  293. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  294. if (opr->input(0)->dtype() != dtype::Float32())
  295. return "Non-Float32 MatrixMul is not supported.";
  296. REPLACE_FAIL_MSG_EPILOGUE;
  297. };
  298. m_opr_trait[opr::BatchedMatrixMul::typeinfo()].get_replace_fail_msg =
  299. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  300. if (opr->input(0)->dtype() != dtype::Float32())
  301. return "Non-Float32 MatrixMul is not supported.";
  302. REPLACE_FAIL_MSG_EPILOGUE;
  303. };
  304. m_opr_trait[opr::PowC::typeinfo()].get_replace_fail_msg =
  305. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  306. if (opr->input(0)->dtype() != dtype::Float32())
  307. return "Non-Float32 PowC is not supported.";
  308. if (opr->input(0)->shape().ndim < 3)
  309. return "Dimensions of input should be greater than or equal to "
  310. "3.";
  311. REPLACE_FAIL_MSG_EPILOGUE;
  312. };
  313. #undef REPLACE_FAIL_MSG_EPILOGUE
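// Below, per-operator-type callbacks translate a megbrain operator into the
// corresponding TensorRT layers on the INetworkDefinition.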
// The megdnn convolution opr on the CUDA backend does not support quantized
// dtypes, so we assume that a megbrain int8 network being converted into a
// fine-grained TensorRT subgraph contains no Convolution operator with a
// quantized int8 data type.
  318. m_opr_trait[opr::Convolution::typeinfo()]
  319. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  320. OperatorNodeBase* opr) {
  321. auto&& varnode2itensor =
  322. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  323. VarNode* input = opr->input(0);
  324. VarNode* kernel = opr->input(1);
  325. check_input(input, opr);
  326. nvinfer1::Weights wt_kernel{
  327. nvinfer1::DataType::kFLOAT, get_value(kernel).raw_ptr(),
  328. static_cast<int64_t>(kernel->shape().total_nr_elems())};
  329. nvinfer1::Weights wt_bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
  330. auto&& param = opr->cast_final_safe<opr::Convolution>().param();
  331. mgb_assert(
  332. param.format == megdnn::param::Convolution::Format::NCHW &&
  333. param.mode ==
  334. megdnn::param::Convolution::Mode::CROSS_CORRELATION,
  335. "conv param is not supported by TensorRT");
  336. size_t group_offset = 0;
  337. if (param.sparse == megdnn::param::Convolution::Sparse::GROUP) {
  338. group_offset = 1;
  339. } else {
  340. mgb_assert(
  341. param.sparse == megdnn::param::Convolution::Sparse::DENSE,
  342. "param.sparse should be GROUP or DENSE");
  343. }
  344. auto conv = net->addConvolution(
  345. *varnode2itensor[input], opr->output(0)->shape()[1],
  346. nvinfer1::DimsHW{
  347. static_cast<int>(kernel->shape()[group_offset + 2]),
  348. static_cast<int>(kernel->shape()[group_offset + 3])},
  349. wt_kernel, wt_bias);
  350. mgb_assert(conv, "construct network failed");
  351. std::string layer_name = "TRT_CONV:" + opr->name();
  352. conv->setName(layer_name.c_str());
  353. conv->setStride(nvinfer1::DimsHW{
  354. static_cast<int>(param.stride_h),
  355. static_cast<int>(param.stride_w)});
  356. conv->setPadding(nvinfer1::DimsHW{
  357. static_cast<int>(param.pad_h), static_cast<int>(param.pad_w)});
  358. conv->setDilation(nvinfer1::DimsHW{
  359. static_cast<int>(param.dilate_h),
  360. static_cast<int>(param.dilate_w)});
  361. if (group_offset > 0)
  362. conv->setNbGroups(static_cast<int>(kernel->shape()[0]));
  363. m_opr2convlayer[opr] = conv;
  364. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  365. conv->getOutput(0)->setName(output_name.c_str());
  366. varnode2itensor[opr->output(0)] = conv->getOutput(0);
  367. };
  368. // support floating point data type and quantized data type
  369. m_opr_trait[opr::ConvBiasForward::typeinfo()].add_to_nvinfer =
  370. [this](nvinfer1::INetworkDefinition* net, OperatorNodeBase* opr) {
  371. auto&& varnode2itensor =
  372. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  373. using Param = opr::ConvBias::Param;
  374. using NonlineMode = Param::NonlineMode;
  375. using Sparse = Param::Sparse;
  376. using Format = Param::Format;
  377. auto conv_bias = try_cast_as_op<opr::ConvBias>(opr);
  378. auto&& param = conv_bias->param();
  379. mgb_assert(
  380. param.mode == Param::Mode::CROSS_CORRELATION,
  381. "Trt only support CROSS_CORRELATION convolution.");
  382. bool is_format_nchw4 = param.format == Format::NCHW4;
  383. bool is_qint8 = is_quantized_int8_operator(opr);
  384. if (is_format_nchw4)
  385. mgb_assert(is_qint8);
  386. // set kernel and bias
  387. VarNode* input = conv_bias->input(0);
  388. VarNode* kernel = conv_bias->input(1);
  389. check_input(input, opr);
  390. nvinfer1::Weights wt_kernel{
  391. nvinfer1::DataType::kFLOAT,
  392. get_value(kernel, param.format).raw_ptr(),
  393. static_cast<int64_t>(kernel->shape().total_nr_elems())};
  394. nvinfer1::Weights wt_bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
  395. if (conv_bias->input().size() >= 3) {
  396. VarNode* bias = conv_bias->input(2);
  397. wt_bias.values = get_value(bias, param.format).raw_ptr();
  398. wt_bias.count =
  399. static_cast<int64_t>(bias->shape().total_nr_elems());
  400. }
  401. // determine conv shape
  402. int co = 0;
  403. int sh = param.stride_h, sw = param.stride_w, ph = param.pad_h,
  404. pw = param.pad_w, dh = param.dilate_h, dw = param.dilate_w;
  405. size_t group_offset = 0;
  406. int groups = 1;
  407. if (param.sparse == Sparse::GROUP) {
  408. groups = kernel->shape()[0];
  409. group_offset = 1;
  410. } else {
  411. mgb_assert(
  412. param.sparse == Sparse::DENSE,
  413. "sparse should be GROUP or DENSE");
  414. }
  415. int fh = kernel->shape()[group_offset + 2],
  416. fw = kernel->shape()[group_offset + 3];
  417. if (param.format == Format::NCHW) {
  418. mgb_assert(
  419. conv_bias->input(0)->dtype() == dtype::Float32(),
  420. "conv bias only support Float32 with NCHW format");
  421. co = conv_bias->output(0)->shape()[1];
  422. } else if (param.format == Format::NCHW4) {
  423. mgb_assert(
  424. conv_bias->input(0)->dtype().enumv() ==
  425. DTypeEnum::QuantizedS8 &&
  426. conv_bias->output(0)->dtype().enumv() ==
  427. DTypeEnum::QuantizedS8,
  428. "conv bias only support QuantizedS8 with NCHW4 format");
  429. co = conv_bias->output(0)->shape()[1] * 4;
  430. }
  431. mgb_assert(co > 0);
  432. // process conv
  433. auto conv = net->addConvolution(
  434. *varnode2itensor[input], co, nvinfer1::DimsHW{fh, fw},
  435. wt_kernel, wt_bias);
  436. mgb_assert(conv, "construct network failed");
  437. std::string layer_name = "TRT_CONV:" + conv_bias->name();
  438. conv->setName(layer_name.c_str());
  439. conv->setStride(nvinfer1::DimsHW{sh, sw});
  440. conv->setPadding(nvinfer1::DimsHW{ph, pw});
  441. conv->setDilation(nvinfer1::DimsHW{dh, dw});
  442. if (group_offset > 0)
  443. conv->setNbGroups(groups);
  444. std::string output_name = "TRT_O:" + conv_bias->output(0)->name();
  445. conv->getOutput(0)->setName(output_name.c_str());
  446. varnode2itensor[conv_bias->output(0)] = conv->getOutput(0);
  447. if (is_qint8)
  448. set_itensor_dynamic_range(conv_bias->output(0), conv_bias);
  449. // process short cut add
  450. if (conv_bias->input().size() >= 4) {
  451. check_input(conv_bias->input(3), opr);
  452. auto add = net->addElementWise(
  453. *varnode2itensor[conv_bias->output(0)],
  454. *varnode2itensor[conv_bias->input(3)],
  455. nvinfer1::ElementWiseOperation::kSUM);
  456. mgb_assert(add, "construct network failed");
  457. std::string layer_name = "TRT_ELEM:" + conv_bias->name();
  458. add->setName(layer_name.c_str());
  459. std::string output_name =
  460. "TRT_O:" + conv_bias->output(0)->name() +
  461. "_shortcut_add";
  462. add->getOutput(0)->setName(output_name.c_str());
  463. varnode2itensor[conv_bias->output(0)] = add->getOutput(0);
  464. if (is_qint8)
  465. set_itensor_dynamic_range(conv_bias->output(0), conv_bias);
  466. }
  467. // process activation
  468. if (param.nonlineMode != Param::NonlineMode::IDENTITY) {
  469. nvinfer1::ActivationType act_type =
  470. param.nonlineMode == NonlineMode::RELU
  471. ? nvinfer1::ActivationType::kRELU
  472. : nvinfer1::ActivationType::kSIGMOID;
  473. auto act = net->addActivation(
  474. *varnode2itensor[conv_bias->output(0)], act_type);
  475. mgb_assert(act, "construct network failed");
  476. std::string layer_name = "TRT_ACTV:" + conv_bias->name();
  477. act->setName(layer_name.c_str());
  478. std::string output_name =
  479. "TRT_O:" + conv_bias->output(0)->name() + "_act";
  480. act->getOutput(0)->setName(output_name.c_str());
  481. varnode2itensor[conv_bias->output(0)] = act->getOutput(0);
  482. if (is_qint8)
  483. set_itensor_dynamic_range(conv_bias->output(0), conv_bias);
  484. }
  485. };
  486. // megbrain deconvolution operator does not support quantized data type
  487. m_opr_trait[opr::ConvolutionBackwardData::typeinfo()]
  488. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  489. OperatorNodeBase* opr) {
  490. auto&& varnode2itensor =
  491. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  492. VarNode* kernel = opr->input(0);
  493. VarNode* input = opr->input(1);
  494. check_input(input, opr);
  495. nvinfer1::Weights wt_kernel{
  496. nvinfer1::DataType::kFLOAT, get_value(kernel).raw_ptr(),
  497. static_cast<int64_t>(kernel->shape().total_nr_elems())};
  498. nvinfer1::Weights wt_bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
  499. auto&& param = opr->cast_final_safe<opr::ConvolutionBackwardData>().param();
  500. mgb_assert(
  501. param.format == megdnn::param::Convolution::Format::NCHW &&
  502. param.mode == megdnn::param::Convolution::Mode::
  503. CROSS_CORRELATION &&
  504. param.dilate_h == 1 && param.dilate_w == 1,
  505. "conv param is not supported by TensorRT");
  506. size_t group_offset = 0;
  507. if (param.sparse == megdnn::param::Convolution::Sparse::GROUP) {
  508. group_offset = 1;
  509. } else {
  510. mgb_assert(
  511. param.sparse == megdnn::param::Convolution::Sparse::DENSE,
  512. "param.sparse should be GROUP or DENSE");
  513. }
  514. auto deconv = net->addDeconvolution(
  515. *varnode2itensor[input], opr->output(0)->shape()[1],
  516. nvinfer1::DimsHW{
  517. static_cast<int>(kernel->shape()[group_offset + 2]),
  518. static_cast<int>(kernel->shape()[group_offset + 3])},
  519. wt_kernel, wt_bias);
  520. mgb_assert(deconv, "construct network failed");
  521. std::string layer_name = "TRT_DCON:" + opr->name();
  522. deconv->setName(layer_name.c_str());
  523. deconv->setStride(nvinfer1::DimsHW{
  524. static_cast<int>(param.stride_h),
  525. static_cast<int>(param.stride_w)});
  526. deconv->setPadding(nvinfer1::DimsHW{
  527. static_cast<int>(param.pad_h), static_cast<int>(param.pad_w)});
  528. if (group_offset > 0)
  529. deconv->setNbGroups(static_cast<int>(kernel->shape()[0]));
  530. m_opr2deconvlayer[opr] = deconv;
  531. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  532. deconv->getOutput(0)->setName(output_name.c_str());
  533. varnode2itensor[opr->output(0)] = deconv->getOutput(0);
  534. };
  535. // support floating point data type and quantized data type
  536. m_opr_trait[opr::Pooling::typeinfo()]
  537. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  538. OperatorNodeBase* opr) {
  539. auto&& varnode2itensor =
  540. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  541. using Param = opr::Pooling::Param;
  542. using Mode = Param::Mode;
  543. using Format = Param::Format;
  544. static ThinHashMap<Mode, nvinfer1::PoolingType> pooling_type_map = {
  545. {Mode::MAX, nvinfer1::PoolingType::kMAX},
  546. {Mode::AVERAGE, nvinfer1::PoolingType::kAVERAGE},
  547. {Mode::AVERAGE_COUNT_EXCLUDE_PADDING,
  548. nvinfer1::PoolingType::kAVERAGE}};
  549. auto&& param = opr->cast_final_safe<opr::Pooling>().param();
  550. check_input(opr->input(0), opr);
  551. auto pool = net->addPooling(
  552. *varnode2itensor[opr->input(0)], pooling_type_map.at(param.mode),
  553. nvinfer1::DimsHW{
  554. static_cast<int>(param.window_h),
  555. static_cast<int>(param.window_w)});
  556. mgb_assert(pool, "construct network failed");
  557. std::string layer_name = "TRT_POOL:" + opr->name();
  558. pool->setName(layer_name.c_str());
  559. pool->setPadding(nvinfer1::DimsHW{
  560. static_cast<int>(param.pad_h), static_cast<int>(param.pad_w)});
  561. pool->setStride(nvinfer1::DimsHW{
  562. static_cast<int>(param.stride_h),
  563. static_cast<int>(param.stride_w)});
//! According to the TensorRT documentation, average pooling excludes
//! padding from the count by default, so the flag must be set to false
//! when the pooling mode is AVERAGE (which counts padding).
  567. if (param.mode == Mode::AVERAGE_COUNT_EXCLUDE_PADDING)
  568. pool->setAverageCountExcludesPadding(true);
  569. else if (param.mode == Mode::AVERAGE)
  570. pool->setAverageCountExcludesPadding(false);
  571. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  572. pool->getOutput(0)->setName(output_name.c_str());
  573. varnode2itensor[opr->output(0)] = pool->getOutput(0);
  574. if (param.format == Format::NCHW4) {
  575. mgb_assert(
  576. opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8,
  577. "Pooling with NCHW4 format should use quantized "
  578. "int8 data type");
  579. set_itensor_dynamic_range(opr->output(0), opr);
  580. }
  581. };
  582. m_opr_trait[opr::Concat::typeinfo()].add_to_nvinfer =
  583. [this](nvinfer1::INetworkDefinition* net, OperatorNodeBase* opr) {
  584. auto&& varnode2itensor =
  585. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  586. size_t input_size = opr->input().size();
  587. std::unique_ptr<nvinfer1::ITensor*[]> input_tensors(
  588. new nvinfer1::ITensor*[input_size]);
  589. for (size_t i = 0; i < input_size; ++i) {
  590. check_input(opr->input(i), opr);
  591. input_tensors[i] = varnode2itensor[opr->input(i)];
  592. }
  593. auto concat = net->addConcatenation(
  594. input_tensors.get(), static_cast<int>(input_size));
  595. mgb_assert(concat, "construct Concatenation layer failed!");
  596. std::string layer_name = "TRT_CCAT:" + opr->name();
  597. concat->setName(layer_name.c_str());
  598. int axis = opr->cast_final_safe<opr::Concat>().param().axis;
  599. concat->setAxis(axis);
  600. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  601. concat->getOutput(0)->setName(output_name.c_str());
  602. varnode2itensor[opr->output(0)] = concat->getOutput(0);
  603. if (is_quantized_int8_operator(opr)) {
  604. set_itensor_dynamic_range(opr->output(0), opr);
  605. }
  606. };
  607. // support floating point data type and quantized data type
  608. m_opr_trait[opr::Elemwise::typeinfo()]
  609. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  610. OperatorNodeBase* opr) {
  611. auto&& varnode2itensor =
  612. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  613. using Mode = opr::Elemwise::Mode;
  614. auto mode = opr->cast_final_safe<opr::Elemwise>().param().mode;
  615. auto get_dimtype = [&](int ndim) {
  616. SmallVector<TENSORRT_NO_DIMENSIONTYPE(nvinfer1::DimensionType)>
  617. dimtypes(ndim);
  618. for (int i = 0; i < ndim; i++) {
  619. dimtypes[i] = TENSORRT_NO_DIMENSIONTYPE_VALUE(
  620. nvinfer1::DimensionType::kSPATIAL);
  621. }
  622. return dimtypes;
  623. };
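// Helper lambdas that emit the TensorRT layer for an Elemwise opr, depending
// on its arity: a unary layer, an activation layer, or an element-wise
// binary layer.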
  624. auto on_elemwise_arity_unary = [this, &varnode2itensor, &net, &opr,
  625. &get_dimtype](
  626. nvinfer1::UnaryOperation unary_op) {
  627. size_t tensor_ndim = opr->input(0)->shape().ndim;
  628. check_input(opr->input(0), opr, get_dimtype(tensor_ndim));
  629. auto unary = net->addUnary(*varnode2itensor[opr->input(0)], unary_op);
  630. mgb_assert(unary, "construct network failed");
  631. std::string layer_name = "TRT_UNARY:" + opr->name();
  632. unary->setName(layer_name.c_str());
  633. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  634. unary->getOutput(0)->setName(output_name.c_str());
  635. varnode2itensor[opr->output(0)] = unary->getOutput(0);
  636. };
  637. auto on_elemwise_arity_activation =
  638. [this, &varnode2itensor, &net, &opr,
  639. &get_dimtype](nvinfer1::ActivationType act_type) {
  640. size_t tensor_ndim = opr->input(0)->shape().ndim;
  641. check_input(opr->input(0), opr, get_dimtype(tensor_ndim));
  642. auto act = net->addActivation(
  643. *varnode2itensor[opr->input(0)], act_type);
  644. mgb_assert(act, "construct network failed");
  645. std::string layer_name = "TRT_ACTV:" + opr->name();
  646. act->setName(layer_name.c_str());
  647. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  648. act->getOutput(0)->setName(output_name.c_str());
  649. varnode2itensor[opr->output(0)] = act->getOutput(0);
  650. };
  651. auto on_elemwise_arity_binary =
  652. [this, &varnode2itensor, &net, &opr,
  653. &get_dimtype](nvinfer1::ElementWiseOperation elem_op) {
  654. size_t ndim0 = opr->input(0)->shape().ndim,
  655. ndim1 = opr->input(1)->shape().ndim;
  656. mgb_assert(ndim0 == ndim1);
  657. size_t tensor_ndim = ndim0;
  658. bool inp0_new = check_input(
  659. opr->input(0), opr, get_dimtype(tensor_ndim));
  660. bool inp1_new = check_input(
  661. opr->input(1), opr, get_dimtype(tensor_ndim));
  662. if (inp0_new && inp1_new) {
  663. mgb_log_warn(
  664. "Both operands of Elemwise are newly prepared. "
  665. "This is rare. "
  666. "Please check. opr=%s inputs=%s",
  667. opr->cname(),
  668. cg::dump_var_info(opr->input()).c_str());
  669. }
  670. auto dims0 = varnode2itensor[opr->input(0)]->getDimensions(),
  671. dims1 = varnode2itensor[opr->input(1)]->getDimensions();
  672. mgb_throw_if(
  673. dims0.nbDims != dims1.nbDims, AssertionError,
  674. "Input dimensions of two input tensors must be "
  675. "equal (got: %d, %d).",
  676. dims0.nbDims, dims1.nbDims);
  677. auto elem = net->addElementWise(
  678. *varnode2itensor[opr->input(0)],
  679. *varnode2itensor[opr->input(1)], elem_op);
  680. mgb_assert(elem, "construct network failed");
  681. std::string layer_name = "TRT_ELEM:" + opr->name();
  682. elem->setName(layer_name.c_str());
  683. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  684. elem->getOutput(0)->setName(output_name.c_str());
  685. varnode2itensor[opr->output(0)] = elem->getOutput(0);
  686. };
  687. switch (mode) {
  688. #define cb(mode) \
  689. case Mode::mode: \
  690. on_elemwise_arity_unary(nvinfer1::UnaryOperation::k##mode); \
  691. break;
  692. #if NV_TENSOR_RT_VERSION >= 5105
  693. #define MGB_FOREACH_UNARY_OPERATION(cb) \
  694. cb(EXP) cb(LOG) cb(ABS) cb(SIN) cb(COS) cb(ASIN) cb(ACOS) cb(CEIL) cb(FLOOR)
  695. #else
  696. #define MGB_FOREACH_UNARY_OPERATION(cb) cb(EXP) cb(LOG) cb(ABS)
  697. #endif
  698. MGB_FOREACH_UNARY_OPERATION(cb)
  699. #undef cb
  700. #undef MGB_FOREACH_UNARY_OPERATION
  701. #define cb(mode) \
  702. case Mode::mode: \
  703. on_elemwise_arity_activation(nvinfer1::ActivationType::k##mode); \
  704. break;
  705. #define MGB_FOREACH_ACTIVATION_TYPE(cb) cb(RELU) cb(SIGMOID) cb(TANH)
  706. MGB_FOREACH_ACTIVATION_TYPE(cb)
  707. #undef cb
  708. #undef MGB_FOREACH_ACTIVATION_TYPE
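// ADD first tries to fold a constant operand into the bias of a parent
// (de)convolution layer found in the same TensorRT subgraph; otherwise it
// falls back to an element-wise SUM layer.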
  709. case Mode::ADD: {
  710. VarNode *opr_var, *bias_var;
  711. VarNodeArray result = find_parent_conv(opr);
  712. if (result.size() > 0) {
  713. opr_var = result[0];
  714. bias_var = result[1];
  715. nvinfer1::Weights wt_bias{
  716. nvinfer1::DataType::kFLOAT,
  717. get_value(bias_var).raw_ptr(),
  718. static_cast<int64_t>(
  719. bias_var->shape().total_nr_elems())};
  720. if (opr_var->owner_opr()->same_type<opr::Convolution>()) {
  721. m_opr2convlayer[opr_var->owner_opr()]->setBiasWeights(
  722. wt_bias);
  723. } else if (opr_var->owner_opr()
  724. ->same_type<
  725. opr::ConvolutionBackwardData>()) {
  726. m_opr2deconvlayer[opr_var->owner_opr()]->setBiasWeights(
  727. wt_bias);
  728. }
  729. varnode2itensor[opr->output(0)] = varnode2itensor[result[2]];
  730. break;
  731. }
  732. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kSUM);
  733. break;
  734. }
  735. case Mode::MUL:
  736. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kPROD);
  737. break;
  738. case Mode::MIN:
  739. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kMIN);
  740. break;
  741. case Mode::MAX:
  742. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kMAX);
  743. break;
  744. case Mode::SUB:
  745. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kSUB);
  746. break;
  747. case Mode::TRUE_DIV:
  748. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kDIV);
  749. break;
  750. case Mode::POW:
  751. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kPOW);
  752. break;
  753. case Mode::FUSE_ADD_RELU: {
  754. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kSUM);
  755. if (is_quantized_int8_operator(opr))
  756. set_itensor_dynamic_range(opr->output(0), opr);
  757. auto act = net->addActivation(
  758. *varnode2itensor[opr->output(0)],
  759. nvinfer1::ActivationType::kRELU);
  760. mgb_assert(act, "construct network failed");
  761. std::string layer_name = "TRT_ACTV:" + opr->name();
  762. act->setName(layer_name.c_str());
  763. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  764. act->getOutput(0)->setName(output_name.c_str());
  765. varnode2itensor[opr->output(0)] = act->getOutput(0);
  766. break;
  767. }
  768. case Mode::FUSE_ADD_SIGMOID: {
  769. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kSUM);
  770. if (is_quantized_int8_operator(opr))
  771. set_itensor_dynamic_range(opr->output(0), opr);
  772. auto act = net->addActivation(
  773. *varnode2itensor[opr->output(0)],
  774. nvinfer1::ActivationType::kSIGMOID);
  775. mgb_assert(act, "construct network failed");
  776. std::string layer_name = "TRT_ACTV:" + opr->name();
  777. act->setName(layer_name.c_str());
  778. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  779. act->getOutput(0)->setName(output_name.c_str());
  780. varnode2itensor[opr->output(0)] = act->getOutput(0);
  781. break;
  782. }
  783. case Mode::FUSE_ADD_TANH: {
  784. on_elemwise_arity_binary(nvinfer1::ElementWiseOperation::kSUM);
  785. if (is_quantized_int8_operator(opr))
  786. set_itensor_dynamic_range(opr->output(0), opr);
  787. auto act = net->addActivation(
  788. *varnode2itensor[opr->output(0)],
  789. nvinfer1::ActivationType::kTANH);
  790. mgb_assert(act, "construct network failed");
  791. std::string layer_name = "TRT_ACTV:" + opr->name();
  792. act->setName(layer_name.c_str());
  793. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  794. act->getOutput(0)->setName(output_name.c_str());
  795. varnode2itensor[opr->output(0)] = act->getOutput(0);
  796. break;
  797. }
  798. default:
  799. mgb_assert(false, "Unsupported elemwise mode.");
  800. }
  801. if (is_quantized_int8_operator(opr))
  802. set_itensor_dynamic_range(opr->output(0), opr);
  803. };
  804. m_opr_trait[opr::ElemwiseMultiType::typeinfo()]
  805. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  806. OperatorNodeBase* opr) {
  807. auto&& varnode2itensor =
  808. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  809. size_t ndim0 = opr->input(0)->shape().ndim,
  810. ndim1 = opr->input(1)->shape().ndim;
  811. mgb_assert(ndim0 == ndim1);
  812. size_t tensor_ndim = ndim0;
  813. using Mode = opr::ElemwiseMultiType::Mode;
  814. SmallVector<TENSORRT_NO_DIMENSIONTYPE(nvinfer1::DimensionType)> dimtypes(
  815. tensor_ndim);
  816. for (size_t i = 0; i < tensor_ndim; i++) {
  817. dimtypes[i] = TENSORRT_NO_DIMENSIONTYPE_VALUE(
  818. nvinfer1::DimensionType::kSPATIAL);
  819. }
  820. auto mode = opr->cast_final_safe<opr::ElemwiseMultiType>().param().mode;
  821. mgb_assert(
  822. mode == Mode::QADD || mode == Mode::QFUSE_ADD_RELU,
  823. "Only QADD and QFUSE_ADD_RELU are supported on CUDA.");
  824. mgb_assert(
  825. opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8,
  826. "output data type %s is not supported",
  827. opr->output(0)->dtype().name());
  828. check_input(opr->input(0), opr, dimtypes);
  829. check_input(opr->input(1), opr, dimtypes);
  830. auto dims0 = varnode2itensor[opr->input(0)]->getDimensions(),
  831. dims1 = varnode2itensor[opr->input(1)]->getDimensions();
  832. mgb_throw_if(
  833. dims0.nbDims != dims1.nbDims, AssertionError,
  834. "Input dimensions of two input tensors must be "
  835. "equal (got: %d, %d).",
  836. dims0.nbDims, dims1.nbDims);
  837. auto elem = net->addElementWise(
  838. *varnode2itensor[opr->input(0)], *varnode2itensor[opr->input(1)],
  839. nvinfer1::ElementWiseOperation::kSUM);
  840. mgb_assert(elem, "construct network failed");
  841. std::string layer_name = "TRT_ELEM:" + opr->name();
  842. elem->setName(layer_name.c_str());
  843. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  844. elem->getOutput(0)->setName(output_name.c_str());
  845. varnode2itensor[opr->output(0)] = elem->getOutput(0);
  846. set_itensor_dynamic_range(opr->output(0), opr);
  847. if (mode == Mode::QFUSE_ADD_RELU) {
  848. auto act = net->addActivation(
  849. *varnode2itensor[opr->output(0)],
  850. nvinfer1::ActivationType::kRELU);
  851. mgb_assert(act, "construct network failed");
  852. std::string layer_name = "TRT_ACTV:" + opr->name();
  853. act->setName(layer_name.c_str());
  854. std::string output_name = "TRT_O:" + opr->output()[0]->name() + "_act";
  855. act->getOutput(0)->setName(output_name.c_str());
  856. varnode2itensor[opr->output(0)] = act->getOutput(0);
  857. set_itensor_dynamic_range(opr->output(0), opr);
  858. }
  859. };
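// Shared callback for MatrixMul and BatchedMatrixMul: both are lowered to a
// TensorRT matrix-multiply layer, with the transpose flags taken from the
// operator's param.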
  860. auto replace_matmul_opr = [this](nvinfer1::INetworkDefinition* net,
  861. OperatorNodeBase* opr) {
  862. auto&& varnode2itensor =
  863. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  864. SmallVector<TENSORRT_NO_DIMENSIONTYPE(nvinfer1::DimensionType)> dimtypes;
  865. bool transposeA = false, transposeB = false;
  866. if (opr->same_type<opr::MatrixMul>()) {
  867. dimtypes = {
  868. TENSORRT_NO_DIMENSIONTYPE_VALUE(
  869. nvinfer1::DimensionType::kSPATIAL),
  870. TENSORRT_NO_DIMENSIONTYPE_VALUE(
  871. nvinfer1::DimensionType::kSPATIAL)};
  872. transposeA = opr->cast_final_safe<opr::MatrixMul>().param().transposeA;
  873. transposeB = opr->cast_final_safe<opr::MatrixMul>().param().transposeB;
  874. } else {
  875. mgb_assert(opr->same_type<opr::BatchedMatrixMul>());
  876. dimtypes = {
  877. TENSORRT_NO_DIMENSIONTYPE_VALUE(
  878. nvinfer1::DimensionType::kINDEX),
  879. TENSORRT_NO_DIMENSIONTYPE_VALUE(
  880. nvinfer1::DimensionType::kSPATIAL),
  881. TENSORRT_NO_DIMENSIONTYPE_VALUE(
  882. nvinfer1::DimensionType::kSPATIAL)};
  883. transposeA = opr->cast_final_safe<opr::BatchedMatrixMul>()
  884. .param()
  885. .transposeA;
  886. transposeB = opr->cast_final_safe<opr::BatchedMatrixMul>()
  887. .param()
  888. .transposeB;
  889. }
  890. check_input(opr->input(0), opr, dimtypes);
  891. check_input(opr->input(1), opr, dimtypes);
  892. #if NV_TENSOR_RT_VERSION >= 6001
  893. nvinfer1::MatrixOperation
  894. opA = transposeA ? nvinfer1::MatrixOperation::kTRANSPOSE
  895. : nvinfer1::MatrixOperation::kNONE,
  896. opB = transposeB ? nvinfer1::MatrixOperation::kTRANSPOSE
  897. : nvinfer1::MatrixOperation::kNONE;
  898. auto matmul = net->addMatrixMultiply(
  899. *varnode2itensor[opr->input(0)], opA,
  900. *varnode2itensor[opr->input(1)], opB);
  901. #else
  902. auto matmul = net->addMatrixMultiply(
  903. *varnode2itensor[opr->input(0)], transposeA,
  904. *varnode2itensor[opr->input(1)], transposeB);
  905. #endif
  906. std::string layer_name = "TRT_MATMUL:" + opr->name();
  907. matmul->setName(layer_name.c_str());
  908. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  909. matmul->getOutput(0)->setName(output_name.c_str());
  910. varnode2itensor[opr->output(0)] = matmul->getOutput(0);
  911. };
  912. // megdnn matrix mul operator on cuda backend does not support quantized
  913. // data type
  914. m_opr_trait[opr::MatrixMul::typeinfo()].add_to_nvinfer = replace_matmul_opr;
  915. m_opr_trait[opr::BatchedMatrixMul::typeinfo()].add_to_nvinfer =
  916. replace_matmul_opr;
// PowC only supports float32
  918. m_opr_trait[opr::PowC::typeinfo()]
  919. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  920. OperatorNodeBase* opr) {
  921. auto&& varnode2itensor =
  922. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  923. size_t tensor_ndim = opr->input(0)->shape().ndim;
  924. SmallVector<TENSORRT_NO_DIMENSIONTYPE(nvinfer1::DimensionType)> dimtypes(
  925. tensor_ndim);
  926. for (size_t i = 0; i < tensor_ndim; i++) {
  927. dimtypes[i] = TENSORRT_NO_DIMENSIONTYPE_VALUE(
  928. nvinfer1::DimensionType::kSPATIAL);
  929. }
  930. check_input(opr->input(0), opr, dimtypes);
  931. auto host_one = HostTensorND(
  932. opr->output(0)->comp_node(), {1}, dtype::Float32()),
  933. host_zero = HostTensorND(
  934. opr->output(0)->comp_node(), {1}, dtype::Float32()),
  935. host_exp = HostTensorND(
  936. opr->output(0)->comp_node(), {1}, dtype::Float32());
  937. *(reinterpret_cast<float*>(host_one.raw_ptr())) = 1;
  938. *(reinterpret_cast<float*>(host_zero.raw_ptr())) = 0;
  939. *(reinterpret_cast<float*>(host_exp.raw_ptr())) =
  940. opr->cast_final_safe<opr::PowC>().param().exp;
  941. auto ptr = opr->owner_graph()
  942. ->options()
  943. .user_data.get_user_data_or_create<HostTensorKeeper>();
  944. ptr->htr.push_back(host_one);
  945. ptr->htr.push_back(host_zero);
  946. ptr->htr.push_back(host_exp);
  947. auto scale = net->addScale(
  948. *varnode2itensor[opr->input(0)], nvinfer1::ScaleMode::kUNIFORM,
  949. nvinfer1::Weights{
  950. nvinfer1::DataType::kFLOAT, host_zero.raw_ptr(), 1},
  951. nvinfer1::Weights{
  952. nvinfer1::DataType::kFLOAT, host_one.raw_ptr(), 1},
  953. nvinfer1::Weights{
  954. nvinfer1::DataType::kFLOAT, host_exp.raw_ptr(), 1});
  955. std::string layer_name = "TRT_SCALE:" + opr->name();
  956. scale->setName(layer_name.c_str());
  957. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  958. scale->getOutput(0)->setName(output_name.c_str());
  959. varnode2itensor[opr->output(0)] = scale->getOutput(0);
  960. };
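// With all operator traits registered, run the pass: detect replaceable
// subgraphs, mark NCHW4 varnodes, and rewrite the graph, then log a summary
// of the operators that could not be replaced.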
  961. m_opr_num = 0;
  962. m_opr_fail_num = 0;
  963. detect_replace();
  964. mark_varnode_format_nchw4();
  965. update_graph();
  966. if (!m_opr_fail.empty()) {
  967. std::string msg{"TRT replace summary:\n"};
  968. msg += ssprintf(" number of oprs: %zu\n", m_opr_num);
  969. msg += ssprintf(" number of unsupported oprs: %zu\n", m_opr_fail_num);
  970. msg += ssprintf(" first %zu unsupported oprs:\n", m_opr_fail.size());
  971. for (size_t i = 0; i < m_opr_fail.size(); ++i) {
  972. msg += ssprintf(
  973. " %s {%s}: %s\n", m_opr_fail[i].opr->cname(),
  974. m_opr_fail[i].opr->dyn_typeinfo()->name,
  975. m_opr_fail[i].fail_msg.c_str());
  976. }
  977. msg.pop_back();
  978. mgb_log("%s", msg.c_str());
  979. }
  980. }
  981. };
  982. MGB_TYPEINFO_OBJ_IMPL(TensorRTReplacePass::Impl::HostTensorKeeper);
  983. Maybe<std::string> TensorRTReplacePass::Impl::has_fail_msg(OperatorNodeBase* opr) {
  984. auto iter = m_opr_trait.find(opr->dyn_typeinfo());
  985. if (iter != m_opr_trait.end()) {
  986. if (iter->second.get_replace_fail_msg) {
  987. return iter->second.get_replace_fail_msg(opr);
  988. }
  989. return None;
  990. }
  991. return "Opr not supported.";
  992. }
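// Starting from an ADD with one constant operand, walk up through chains of
// ADD oprs looking for a Convolution / ConvolutionBackwardData output inside
// the same TensorRT subgraph whose layer bias can absorb the constant; on
// success returns {conv_output, bias, new_output}, otherwise an empty array.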
  993. VarNodeArray TensorRTReplacePass::Impl::find_parent_conv(OperatorNodeBase* inp_opr) {
  994. OperatorNodeBase* owner_opr;
  995. VarNodeArray vars_to_check, new_vars, rst;
  996. bool conv_output_found = false;
  997. VarNode* conv_output_var = nullptr;
  998. VarNode* bias_var = nullptr;
  999. VarNode* new_output_var = nullptr;
  1000. if (m_const_var_propogate->is_const(inp_opr->input(0))) {
  1001. vars_to_check.push_back(inp_opr->input(1));
  1002. new_output_var = inp_opr->input(1);
  1003. bias_var = inp_opr->input(0);
  1004. } else if (m_const_var_propogate->is_const(inp_opr->input(1))) {
  1005. vars_to_check.push_back(inp_opr->input(0));
  1006. new_output_var = inp_opr->input(0);
  1007. bias_var = inp_opr->input(1);
  1008. } else {
  1009. // No const input. return empty rst.
  1010. return rst;
  1011. }
  1012. while (vars_to_check.size() != 0) {
  1013. for (size_t i = 0; i < vars_to_check.size(); ++i) {
  1014. owner_opr = vars_to_check[i]->owner_opr();
  1015. if (owner_opr->same_type<opr::Convolution>() ||
  1016. owner_opr->same_type<opr::ConvolutionBackwardData>()) {
  1017. conv_output_found = true;
  1018. conv_output_var = vars_to_check[i];
  1019. break;
  1020. }
  1021. if (owner_opr->same_type<opr::Elemwise>() &&
  1022. owner_opr->cast_final<opr::Elemwise>().param().mode ==
  1023. opr::Elemwise::Mode::ADD) {
  1024. for (auto var2chk : owner_opr->input()) {
  1025. new_vars.push_back(var2chk);
  1026. }
  1027. }
  1028. }
  1029. vars_to_check.clear();
  1030. if (conv_output_found)
  1031. break;
  1032. if (new_vars.size() != 0) {
  1033. vars_to_check.insert(vars_to_check.end(), new_vars.begin(), new_vars.end());
  1034. new_vars.clear();
  1035. }
  1036. }
  1037. if (conv_output_found) {
  1038. conv_output_found &=
  1039. m_graph_map[inp_opr] == m_graph_map[conv_output_var->owner_opr()];
  1040. auto&& trt_graph = m_tensorrt_graphs[m_graph_map[inp_opr] - 1];
  1041. conv_output_found &= trt_graph->outputs.count(conv_output_var) == 0;
  1042. }
  1043. if (conv_output_found) {
  1044. rst.push_back(conv_output_var);
  1045. rst.push_back(bias_var);
  1046. rst.push_back(new_output_var);
  1047. }
  1048. return rst;
  1049. }
  1050. bool TensorRTReplacePass::Impl::check_input(
  1051. VarNode* var, OperatorNodeBase* opr,
  1052. SmallVector<TENSORRT_NO_DIMENSIONTYPE(nvinfer1::DimensionType)> dimtypes) {
  1053. auto trt_graph = m_tensorrt_graphs[m_graph_map[opr] - 1];
  1054. auto&& varnode2itensor = trt_graph->varnode2itensor;
  1055. auto iter = trt_graph->inputs.find(var);
if (iter == trt_graph->inputs.end()) // not an input of the trt graph
  1057. return false;
  1058. for (auto i : trt_graph->trt_inputs)
  1059. if (i == var) // already added to input
  1060. return false;
  1061. trt_graph->trt_inputs.push_back(var);
  1062. nvinfer1::ITensor* itensor;
  1063. MGB_MARK_USED_VAR(mgb_dtype_to_trt_dtype);
  1064. if (dimtypes.size() == 0) {
  1065. #if NV_TENSOR_RT_VERSION >= 6001
  1066. mgb_assert(
  1067. var->shape().ndim == 4 ||
  1068. (var->shape().ndim == 5 && var->shape()[4] == 4));
  1069. nvinfer1::Dims4 dims{
  1070. static_cast<int>(var->shape()[0]), static_cast<int>(var->shape()[1]),
  1071. static_cast<int>(var->shape()[2]), static_cast<int>(var->shape()[3])};
  1072. if (var->shape().ndim == 5) {
  1073. mgb_assert(var->shape()[4] == 4);
  1074. dims.d[1] *= 4;
  1075. }
  1076. itensor = trt_graph->network->addInput(
  1077. var->cname(), mgb_dtype_to_trt_dtype(var->dtype()), dims);
  1078. if (trt_graph->mark_input_varnode_nchw4.count(var)) {
  1079. itensor->setAllowedFormats(
  1080. 1 << static_cast<int>(nvinfer1::TensorFormat::kCHW4));
  1081. } else {
  1082. itensor->setAllowedFormats(
  1083. 1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));
  1084. }
  1085. #else
  1086. if (var->shape().ndim == 4) {
  1087. // the default input tensor is a NCHW tensor
  1088. mgb_assert(
  1089. var->shape().ndim == 4,
  1090. "Default input tensor should be NCHW or NCHW4 format.");
  1091. itensor = trt_graph->network->addInput(
  1092. var->cname(), nvinfer1::DataType::kFLOAT,
  1093. nvinfer1::DimsNCHW{
  1094. static_cast<int>(var->shape()[0]),
  1095. static_cast<int>(var->shape()[1]),
  1096. static_cast<int>(var->shape()[2]),
  1097. static_cast<int>(var->shape()[3])});
  1098. } else {
  1099. mgb_assert(
  1100. var->shape().ndim == 5 && var->shape()[4] == 4,
  1101. "Input tensor format is not NCHW4 (got %s)",
  1102. var->shape().to_string().c_str());
  1103. itensor = trt_graph->network->addInput(
  1104. var->cname(), nvinfer1::DataType::kFLOAT,
  1105. nvinfer1::DimsNCHW{
  1106. static_cast<int>(var->shape()[0]),
  1107. static_cast<int>(var->shape()[1] * 4),
  1108. static_cast<int>(var->shape()[2]),
  1109. static_cast<int>(var->shape()[3])});
  1110. }
  1111. #endif
  1112. } else {
  1113. nvinfer1::Dims dims;
// process var nodes that are marked as NCHW4 format
  1115. if (trt_graph->mark_input_varnode_nchw4.count(var)) {
  1116. mgb_assert(
  1117. var->shape().ndim == 5 && var->shape()[4] == 4,
  1118. "Input tensor format is not NCHW4 (got %s)",
  1119. var->shape().to_string().c_str());
  1120. dims.nbDims = var->shape().ndim - 1;
  1121. for (size_t i = 0; i < var->shape().ndim - 1; i++) {
  1122. dims.d[i] = var->shape()[i];
  1123. #if NV_TENSOR_RT_VERSION < 6001
  1124. dims.type[i] = dimtypes[i];
  1125. #endif
  1126. }
  1127. dims.d[1] *= 4;
  1128. // process conventional var node
  1129. } else {
  1130. mgb_assert(var->shape().ndim == dimtypes.size());
  1131. mgb_assert(var->shape().ndim <= nvinfer1::Dims::MAX_DIMS);
  1132. dims.nbDims = var->shape().ndim;
  1133. for (size_t i = 0; i < var->shape().ndim; i++) {
  1134. dims.d[i] = var->shape()[i];
  1135. #if NV_TENSOR_RT_VERSION < 6001
  1136. dims.type[i] = dimtypes[i];
  1137. #endif
  1138. }
  1139. }
  1140. #if NV_TENSOR_RT_VERSION >= 6001
  1141. itensor = trt_graph->network->addInput(
  1142. var->cname(), mgb_dtype_to_trt_dtype(var->dtype()), dims);
  1143. if (trt_graph->mark_input_varnode_nchw4.count(var)) {
  1144. itensor->setAllowedFormats(
  1145. 1 << static_cast<int>(nvinfer1::TensorFormat::kCHW4));
  1146. } else {
  1147. itensor->setAllowedFormats(
  1148. 1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));
  1149. }
  1150. #else
  1151. itensor = trt_graph->network->addInput(
  1152. var->cname(), nvinfer1::DataType::kFLOAT, dims);
  1153. #endif
  1154. }
  1155. varnode2itensor[var] = itensor;
  1156. if (trt_graph->feature_bits == TensorRTGraphFeatureBits::NCHW4_QINT8)
  1157. set_itensor_dynamic_range(var, opr);
  1158. return true;
  1159. }
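
// set_itensor_dynamic_range() below propagates MegEngine's quantization scale
// into TensorRT: for a QuantizedS8 var it calls ITensor::setDynamicRange()
// with the symmetric range [-i8_max * scale, i8_max * scale], where i8_max is
// presumably the int8 maximum (127) defined earlier in this file. The body is
// compiled only when NV_TENSOR_RT_VERSION >= 5020.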
void TensorRTReplacePass::Impl::set_itensor_dynamic_range(
        VarNode* var, OperatorNodeBase* opr) {
    MGB_MARK_USED_VAR(var);
    MGB_MARK_USED_VAR(opr);
#if NV_TENSOR_RT_VERSION >= 5020
    auto&& varnode2itensor = m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
    auto&& tensor = varnode2itensor[var];
    auto&& data_type = var->dtype();
    mgb_assert(data_type.enumv() == DTypeEnum::QuantizedS8);
    float scale = get_scale(data_type);
    tensor->setDynamicRange(-i8_max * scale, i8_max * scale);
#endif
}
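
// get_value() below evaluates a var into a Float32 HostTensorND so that its
// contents can be read on the host (e.g. when baking weights into the TensorRT
// network). For ConvFormat::NCHW the var is compiled and executed as-is; for
// ConvFormat::NCHW4 it is first converted back to a plain layout with
// Dimshuffle + Reshape (plus a TypeCvt for quantized dtypes). A sketch of the
// 5-D case, assuming the usual NCHW4 layout (N, C/4, H, W, 4):
//     Dimshuffle {0, 1, 4, 2, 3}: (N, C/4, H, W, 4) -> (N, C/4, 4, H, W)
//     Reshape to (N, C/4 * 4, H, W) = (N, C, H, W)
// The 6-D branch handles the analogous layout with one extra leading axis
// (presumably grouped conv weights), folding the trailing 4 into axis 2.
// graph_opt.tensorrt is temporarily disabled during the helper compile,
// presumably so that evaluating the value does not re-trigger this pass.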
HostTensorND TensorRTReplacePass::Impl::get_value(VarNode* var, ConvFormat format) {
    auto cg = m_opt_state.graph().comp_graph();
    auto inferred_val = HostTensorND(var->comp_node(), dtype::Float32());
    auto cb = [&](DeviceTensorND& val) { inferred_val.copy_from(val); };
    if (format == ConvFormat::NCHW) {
        mgb_assert(var->dtype() == dtype::Float32());
        auto orig_level = cg->options().log_level;
        cg->options().log_level = 0;
        MGB_TRY { cg->compile({{var, cb}})->execute(); }
        MGB_FINALLY(cg->options().log_level = orig_level);
    } else {
        mgb_assert(format == ConvFormat::NCHW4);
        if (var->shape().ndim == 5) {
            // assume nchw4 layout
            mgb_assert(var->shape()[4] == 4);
            auto x = SymbolVar(var);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
            auto y1 = opr::Reshape::make(y0, tshp);
            if (var->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                var->dtype().enumv() == DTypeEnum::QuantizedS32) {
                y1 = opr::TypeCvt::make(y1, dtype::Float32());
            }
            auto orig_level = cg->options().log_level;
            cg->options().log_level = 0;
            cg->options().graph_opt.tensorrt = false;
            MGB_TRY { cg->compile({{y1.node(), cb}})->execute(); }
            MGB_FINALLY({
                cg->options().log_level = orig_level;
                cg->options().graph_opt.tensorrt = true;
            });
        } else if (var->shape().ndim == 6) {
            // assume nchw4 layout
            mgb_assert(var->shape()[5] == 4);
            mgb_assert(
                    var->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                    var->dtype() == dtype::Float32());
            auto x = SymbolVar(var);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(0), sub(1), sub(2) * 4, sub(3), sub(4)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 2, 5, 3, 4});
            auto y1 = opr::Reshape::make(y0, tshp);
            if (var->dtype().enumv() == DTypeEnum::QuantizedS8) {
                y1 = opr::TypeCvt::make(y1, dtype::Float32());
            }
            auto orig_level = cg->options().log_level;
            cg->options().log_level = 0;
            cg->options().graph_opt.tensorrt = false;
            MGB_TRY { cg->compile({{y1.node(), cb}})->execute(); }
            MGB_FINALLY({
                cg->options().log_level = orig_level;
                cg->options().graph_opt.tensorrt = true;
            });
        }
    }
    auto ptr = var->owner_graph()
                       ->options()
                       .user_data.get_user_data_or_create<HostTensorKeeper>();
    ptr->htr.push_back(inferred_val);
    return inferred_val;
}
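
// get_scale() below extracts the quantization scale from a quantized DType via
// the MEGDNN_FOREACH_QUANTIZED_DTYPE dispatch and throws for non-quantized
// dtypes. It is used when setting TensorRT dynamic ranges above and when
// re-quantizing TensorRT outputs back to QuantizedS8 in the pre-6.0 path of
// update_graph() below.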
float TensorRTReplacePass::Impl::get_scale(DType data_type) {
    float scale = 1.f;
#define cb(_dt)                                   \
    case DTypeTrait<_dt>::enumv:                  \
        scale = data_type.param<_dt>().scale;     \
        break;
    switch (data_type.enumv()) {
        MEGDNN_FOREACH_QUANTIZED_DTYPE(cb);
        default:
            mgb_throw(
                    InternalError, "invalid quantized data type: %s", data_type.name());
    }
    return scale;
#undef cb
}
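
// is_quantized_int8_operator() below decides which feature set an operator
// belongs to (NCHW4_QINT8 vs NCHW_FLOAT). ConvBias is special-cased: only
// input(0) (the feature map) has to be QuantizedS8, presumably because the
// bias input carries a different quantized dtype; for every other operator
// all inputs and output(0) must be QuantizedS8, and a single output per
// operator is assumed.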
bool TensorRTReplacePass::Impl::is_quantized_int8_operator(OperatorNodeBase* opr) {
    bool is_quantized = true;
    if (opr->same_type<opr::ConvBias>()) {
        is_quantized = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        mgb_assert(
                !is_quantized ||
                opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8);
        return is_quantized;
    }
    for (auto&& inp : opr->input()) {
        if (inp->dtype().enumv() != DTypeEnum::QuantizedS8) {
            is_quantized = false;
            break;
        }
    }
    // assume every operator has only one output
    auto&& out = opr->output(0);
    if (out->dtype().enumv() != DTypeEnum::QuantizedS8) {
        is_quantized = false;
    }
    return is_quantized;
}
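
// detect_replace() below partitions the computing graph into TensorRT
// subgraphs. Every replaceable opr is assigned a 1-based subgraph id in
// m_graph_map: starting from the smallest id allowed by its inputs (an input
// produced by an irreplaceable opr forces a strictly larger id), the first
// existing subgraph with matching feature bits is reused, otherwise a new
// TensorRTGraph is created. Vars that cross a subgraph boundary are recorded
// as subgraph inputs/outputs, and graph endpoints produced by replaceable oprs
// become outputs as well. Irreplaceable oprs (except the parameter/input
// providers listed in ignore_types) are counted, and the first few failure
// messages are kept for logging.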
void TensorRTReplacePass::Impl::detect_replace() {
    auto cb = [this](OperatorNodeBase* opr) { m_const_var_propogate->add_opr(opr); };
    m_opt_state.graph().iter(cb);
    auto on_opr = [this](OperatorNodeBase* opr) {
        ++m_opr_num;
        Maybe<std::string> irreplaceable_msg = has_fail_msg(opr);
        TensorRTGraphFeatureBits feature_bits =
                is_quantized_int8_operator(opr) ? TensorRTGraphFeatureBits::NCHW4_QINT8
                                                : TensorRTGraphFeatureBits::NCHW_FLOAT;
        if (!irreplaceable_msg.valid()) {
            size_t max = 1;
            for (auto i : opr->input()) {
                if (!has_fail_msg(i->owner_opr()).valid())
                    update_max(max, m_graph_map[i->owner_opr()]);
                else
                    update_max(max, m_graph_map[i->owner_opr()] + 1);
            }
            size_t max_update = max;
            for (; max_update <= m_tensorrt_graphs.size(); max_update++) {
                TensorRTGraphFeatureBits trt_graph_feature_bits =
                        m_tensorrt_graphs[max_update - 1]->feature_bits;
                if (trt_graph_feature_bits == feature_bits)
                    break;
            }
            max = max_update;
            m_graph_map[opr] = max;
            if (max > m_tensorrt_graphs.size()) {
                opr->output(0)->comp_node().activate();
                m_tensorrt_graphs.push_back(
                        std::make_shared<TensorRTGraph>(feature_bits));
            }
            for (auto i : opr->input()) {
                if (m_graph_map[i->owner_opr()] != max) {
                    m_tensorrt_graphs[max - 1]->inputs.insert(i);
                    if (!has_fail_msg(i->owner_opr()).valid()) {
                        //! TODO: check
                        m_tensorrt_graphs[m_graph_map[i->owner_opr()] - 1]
                                ->outputs.insert(i);
                    }
                }
            }
        } else {
            static const ThinHashSet<Typeinfo*> ignore_types{
                    opr::SharedDeviceTensor::typeinfo(),
                    opr::ImmutableTensor::typeinfo(), opr::Host2DeviceCopy::typeinfo(),
                    opr::MultipleDeviceTensorHolder::typeinfo()};
            if (!ignore_types.count(opr->dyn_typeinfo())) {
                ++m_opr_fail_num;
                if (m_opr_fail.size() < OPR_FAIL_LOG_NUM) {
                    FailInfo fail_info;
                    fail_info.opr = opr;
                    fail_info.fail_msg = irreplaceable_msg.val();
                    m_opr_fail.push_back(fail_info);
                }
            }
            size_t max = 0;
            for (auto i : opr->input()) {
                if (m_graph_map[i->owner_opr()] > max)
                    max = m_graph_map[i->owner_opr()];
                if (!has_fail_msg(i->owner_opr()).valid()) {
                    //! TODO: check
                    m_tensorrt_graphs[m_graph_map[i->owner_opr()] - 1]->outputs.insert(
                            i);
                }
            }
            m_graph_map[opr] = max;
        }
    };
    m_opt_state.graph().iter(on_opr);
    for (auto i : m_opt_state.graph().endpoint_vars()) {
        auto var_node = i.node();
        if (!has_fail_msg(var_node->owner_opr()).valid()) {
            //! TODO: check
            m_tensorrt_graphs[m_graph_map[var_node->owner_opr()] - 1]->outputs.insert(
                    var_node);
        }
    }
}
void TensorRTReplacePass::Impl::mark_varnode_format_nchw4() {
    for (auto trt_graph : m_tensorrt_graphs) {
        trt_graph->mark_varnode_format_nchw4();
    }
}
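
// update_graph() below performs the rewrite in three steps:
//   1. construct_network: every replaceable opr is added to its subgraph's
//      nvinfer1 network via the add_to_nvinfer trait (the network is created
//      lazily; for TensorRT >= 6.0 an explicit-batch INetworkDefinition is
//      used). A single GpuAllocator bound to the one supported comp node is
//      shared by all builders, and extra dependencies from each opr to its
//      subgraph's input vars are recorded for the later traversal.
//   2. markOutput: each subgraph output var is marked as a network output and
//      its index remembered; for TensorRT >= 6.0 the output dtype/format is
//      pinned (kINT8/kCHW4 for quantized NCHW4 outputs, kLINEAR otherwise).
//   3. update_opr: original oprs are replaced by one TensorRTOpr per subgraph
//      and consumers are redirected to its outputs. For TensorRT < 6.0,
//      presumably because the I/O format and dtype cannot be pinned on the
//      network boundary, Dimshuffle/Reshape and TypeCvt oprs are inserted
//      around the TensorRTOpr instead.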
void TensorRTReplacePass::Impl::update_graph() {
    using GpuAllocator = opr::TensorRTOpr::GpuAllocator;
    using TensorRTOpr = opr::TensorRTOpr;
    std::shared_ptr<GpuAllocator> gpu_allocator;
    std::shared_ptr<ExtraDep> extra_dep = std::make_shared<ExtraDep>();
    // construct trt network
    auto construct_network = [this, &gpu_allocator, &extra_dep](OperatorNodeBase* opr) {
        if (!has_fail_msg(opr).valid()) {
            auto cn = opr->output(0)->comp_node();
            auto trt_graph = m_tensorrt_graphs[m_graph_map[opr] - 1];
            auto b = trt_graph->builder;
            mgb_assert(b != nullptr);
            if (!gpu_allocator) {
                gpu_allocator = std::make_shared<GpuAllocator>(cn);
                b->setGpuAllocator(gpu_allocator.get());
            } else {
                auto cn0 = gpu_allocator->comp_node();
                mgb_assert(
                        cn0 == cn,
                        "multiple comp nodes for trt graph are not "
                        "supported: %s %s",
                        cn0.to_string().c_str(), cn.to_string().c_str());
            }
            if (!trt_graph->network) {
#if NV_TENSOR_RT_VERSION >= 6001
                nvinfer1::NetworkDefinitionCreationFlags flags;
                flags = 1 << static_cast<int>(nvinfer1::NetworkDefinitionCreationFlag::
                                                      kEXPLICIT_BATCH);
                trt_graph->network = b->createNetworkV2(flags);
#else
                trt_graph->network = b->createNetwork();
#endif
            }
            // make extra dep
            for (auto&& inp : trt_graph->inputs) {
                extra_dep->operator[](opr).push_back(inp);
            }
            auto iter = m_opr_trait.find(opr->dyn_typeinfo());
            if (iter != m_opr_trait.end()) {
                if (iter->second.add_to_nvinfer) {
                    iter->second.add_to_nvinfer(trt_graph->network, opr);
                }
            }
        }
    };
    m_opt_state.graph().iter(construct_network);
    // trt network markOutput
    for (auto trt_graph : m_tensorrt_graphs) {
        // record traverse order
        size_t idx = 0;
        auto&& varnode2itensor = trt_graph->varnode2itensor;
        for (auto output : trt_graph->outputs) {
            trt_graph->output2idx[output] = idx++;
            trt_graph->network->markOutput(*varnode2itensor[output]);
#if NV_TENSOR_RT_VERSION >= 6001
            if (output->dtype().enumv() == DTypeEnum::QuantizedS8) {
                varnode2itensor[output]->setType(nvinfer1::DataType::kINT8);
            }
            if (trt_graph->mark_output_varnode_nchw4.count(output)) {
                mgb_assert(output->dtype().enumv() == DTypeEnum::QuantizedS8);
                varnode2itensor[output]->setAllowedFormats(
                        1 << static_cast<int>(nvinfer1::TensorFormat::kCHW4));
            } else {
                varnode2itensor[output]->setAllowedFormats(
                        1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));
            }
#endif
        }
    }
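
    // The update_opr lambda below runs once per operator, in an order that
    // also respects the extra dependencies recorded in construct_network:
    // when the first replaceable opr of a subgraph is reached, all producers
    // of the subgraph's inputs have already been visited, so the TensorRTOpr
    // for the whole subgraph is created at that point (trt_outputs is filled
    // lazily and reused by the remaining oprs of the same subgraph).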
    ThinHashSet<OperatorNodeBase*> visited;
    // replace opr by trt
    auto update_opr = [this, &gpu_allocator, &visited](OperatorNodeBase* opr) {
        if (!has_fail_msg(opr).valid()) {
            mgb_assert(gpu_allocator);
            auto trt_graph = m_tensorrt_graphs[m_graph_map[opr] - 1];
            for (auto&& inp : trt_graph->trt_inputs) {
                mgb_assert(visited.count(inp->owner_opr()));
            }
            if (trt_graph->trt_outputs.empty()) {
                // use updated varnode instead of old one
                auto inps = trt_graph->trt_inputs;
                VarNodeArray new_inps{inps.size()};
                for (size_t i = 0; i < inps.size(); i++) {
                    new_inps[i] = m_rewriter.get_var(inps[i]);
#if NV_TENSOR_RT_VERSION < 6001
                    if (trt_graph->mark_input_varnode_nchw4.count(inps[i])) {
                        auto x = SymbolVar(new_inps[i]);
                        auto xshp = opr::GetVarShape::make(x);
                        auto cv = [&x](int v) { return x.make_scalar(v); };
                        auto sub = [&xshp, &cv](int idx) {
                            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
                        };
                        auto tshp = opr::Concat::make(
                                {sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
                        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
                        auto y1 = opr::Reshape::make(y0, tshp);
                        new_inps[i] = y1.node();
                    }
                    if (inps[i]->dtype().enumv() == DTypeEnum::QuantizedS8) {
                        new_inps[i] = opr::TypeCvt::make(new_inps[i], dtype::Float32())
                                              .node();
                    }
#endif
                }
                // now trt_graph does not own the unique_ptr of infer builder
                m_opt_state.call_with_opr(opr, [&] {
                    trt_graph->trt_outputs = cg::to_var_node_array(TensorRTOpr::make(
                            TensorRTOpr::to_shared_ptr_builder(trt_graph->builder),
                            TensorRTOpr::to_shared_ptr_network(trt_graph->network),
                            trt_graph->feature_bits, gpu_allocator,
                            cg::to_symbol_var_array(new_inps)));
                });
                mgb_assert(
                        trt_graph->trt_outputs.size() == trt_graph->outputs.size(),
                        "mgb outputs number != tensorrt outputs number");
            }
            for (auto&& output : opr->output()) {
                if (trt_graph->outputs.count(output)) {
                    size_t output_idx = trt_graph->output2idx[output];
                    VarNode* output_var = trt_graph->trt_outputs[output_idx];
#if NV_TENSOR_RT_VERSION < 6001
                    if (trt_graph->mark_output_varnode_nchw4.count(output)) {
                        auto x = SymbolVar(output_var);
                        auto xshp = opr::GetVarShape::make(x);
                        auto cv = [&x](int v) { return x.make_scalar(v); };
                        auto sub = [&xshp, &cv](int idx) {
                            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
                        };
                        auto tshp = opr::Concat::make(
                                {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
                        auto y0 = opr::Reshape::make(x, tshp);
                        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
                        output_var = y1.node();
                    }
                    if (output->dtype().enumv() == DTypeEnum::QuantizedS8) {
                        float scale = get_scale(output->dtype());
                        output_var = opr::TypeCvt::make(
                                             output_var, dtype::QuantizedS8{scale})
                                             .node();
                    }
#endif
                    m_rewriter.replace_var(
                            output, output_var,
                            mgb_ssprintf_log(
                                    "replace opr: %s", output->owner_opr()->cname())
                                    .c_str());
                }
            }
            visited.insert(opr);
        } else {
            for (auto&& inp : opr->input()) {
                mgb_assert(visited.count(inp->owner_opr()));
            }
            visited.insert(opr);
            m_rewriter.auto_replace_outputs(opr);
        }
    };
    m_opt_state.graph().iter(update_opr, std::move(extra_dep));
    m_rewriter.apply_inplace();
}
const char* TensorRTReplacePass::name() const {
    return mgb_cstr_log("tensorrt_replace");
}

void TensorRTReplacePass::apply(OptState& opt) const {
    if (CompNode::get_device_count(CompNode::DeviceType::CUDA)) {
        opt.set_var_replace_check_flag(
                gopt::VarReplaceCheckFlag::CHECK_SHAPE |
                gopt::VarReplaceCheckFlag::CHECK_DTYPE);
        Impl(*this, opt);
    } else {
        mgb_log_debug("cuda is not available; TensorRTReplacePass is ignored");
    }
}

// ===================== TensorRTGraph =================
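
// A note on the union-find below: when two trees are merged, a root that
// is_format_nchw4() stays on top, so a component that contains at least one
// NCHW4 opr always ends up with an NCHW4 root and a single get_root() call is
// enough to classify the component. Owner oprs of subgraph inputs are seeded
// as their own roots, marked visited so DepOprIter does not walk past them,
// and placed in `outsides` so they never count as NCHW4 themselves.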
void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() {
    // Treat the TensorRT subgraph as an undirected graph and divide it into
    // connected components; mark a subgraph input or output varnode as nchw4
    // format iff the varnode belongs to a connected component that contains at
    // least one NCHW4 operator (e.g. ConvBias, Pooling).
    // p[arent] array used for the disjoint set (union-find)
    ThinHashMap<OperatorNodeBase*, OperatorNodeBase*> p;
    ThinHashSet<OperatorNodeBase*> outsides;
    thin_function<OperatorNodeBase*(OperatorNodeBase*)> get_root;
    get_root = [&](OperatorNodeBase* opr) -> OperatorNodeBase* {
        mgb_assert(p.count(opr));
        return p[opr] == opr ? opr : p[opr] = get_root(p[opr]);
    };
    auto is_format_nchw4 = [&](OperatorNodeBase* opr) {
        if (outsides.count(opr)) {
            return false;
        }
        if (opr->same_type<opr::ConvBias>()) {
            auto&& param = opr->cast_final_safe<opr::ConvBias>().param();
            if (param.format == opr::ConvBias::Param::Format::NCHW4)
                return true;
        }
        if (opr->same_type<opr::Pooling>()) {
            auto&& param = opr->cast_final_safe<opr::Pooling>().param();
            if (param.format == opr::Pooling::Param::Format::NCHW4)
                return true;
        }
        return false;
    };
    auto cb = [&](OperatorNodeBase* opr) {
        mgb_assert(!p.count(opr));
        p[opr] = opr;
        for (auto&& inp : opr->input()) {
            auto root = get_root(inp->owner_opr());
            // ensure that if one of the oprs in a tree is nchw4,
            // the root of that tree is nchw4 as well
            if (is_format_nchw4(root)) {
                p[get_root(opr)] = root;
            } else {
                p[root] = get_root(opr);
            }
        }
    };
    DepOprIter iter{cb};
    for (auto&& inp : inputs) {
        p[inp->owner_opr()] = inp->owner_opr();
        iter.set_visited(inp->owner_opr());
        outsides.insert(inp->owner_opr());
    }
    for (auto&& out : outputs) {
        iter.add(out->owner_opr());
    }
    for (auto&& inp : inputs) {
        if (is_format_nchw4(get_root(inp->owner_opr()))) {
            mark_input_varnode_nchw4.insert(inp);
        }
    }
    for (auto&& out : outputs) {
        if (is_format_nchw4(get_root(out->owner_opr()))) {
            mark_output_varnode_nchw4.insert(out);
        }
    }
}
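
// transform_dest_vars_inplace() below wires the pass into a full optimizer
// pipeline: optionally fuse conv/bias/nonlinearity and convert the graph to
// NCHW4 first (when the caller requested nchw4), then expand fused arithmetic,
// run the TensorRT replacement, re-fuse arithmetic, and for TensorRT < 6.0 run
// the shuffle/typecvt removal passes, presumably to clean up the
// Dimshuffle/Reshape and TypeCvt pairs inserted around TensorRTOpr boundaries.
// A minimal call-site sketch (the enable_nchw4() helper is assumed here; any
// way of setting the nchw4 flag on GraphCommonOptimizeOptions works, since
// this function clears it again via disable_nchw4()):
//
//     mgb::cg::VarNodeArray dest_vars = /* graph endpoints */;
//     cg::GraphCommonOptimizeOptions options;
//     options.enable_nchw4();
//     mgb::tensorrt::transform_dest_vars_inplace(dest_vars, options);
//     // dest_vars now holds the rewritten endpoints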
void mgb::tensorrt::transform_dest_vars_inplace(
        mgb::cg::VarNodeArray& dest_vars, cg::GraphCommonOptimizeOptions& options) {
    gopt::GraphOptimizer optimizer;
    //! In MegEngine the default layout is NCHW, while the TensorRT pass
    //! currently only supports NCHW4 (int8), so we transform the layout to
    //! NCHW4 first.
    if (options.has_set_nchw4()) {
        options.disable_nchw4();
        optimizer.add_pass<FuseConvBiasNonlinPass>();
        optimizer.add_pass(EnableNCHW4Pass::make_nchw4_converter());
    }
    optimizer.add_pass<ExpandFusedArithPass>();
    optimizer.add_pass<gopt::TensorRTReplacePass>();
    optimizer.add_pass<ArithFusePass>();
#if NV_TENSOR_RT_VERSION < 6001
    optimizer.add_pass<ShuffleShuffleRemovePass>();
    optimizer.add_pass<RemoveRedundantTypeCvtPass>();
#endif
    optimizer.apply_inplace(dest_vars);
}
#pragma GCC diagnostic pop
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}