You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

convolution.cpp 94 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399
  1. /**
  2. * \file src/opr/impl/dnn/convolution.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/dnn/convolution.h"
  12. #include "megbrain/opr/io.h"
  13. #include "megbrain/graph/grad_impl.h"
  14. #include "megbrain/system.h"
  15. #include "megbrain/utils/hash_ct.h"
  16. #include "megbrain/utils/timer.h"
  17. #include "megdnn/oprs/utils.h"
  18. //! TODO: some megdnn::opr has to be known here when midout.h is produced;
  19. //! fix it if there is a more graceful way.
  20. #include "megdnn/oprs.h"
  21. #include "midout.h"
  22. MIDOUT_DECL(megbrain_opr_convolution)
  23. #define MIDOUT_B(...) \
  24. MIDOUT_BEGIN(megbrain_opr_convolution, __VA_ARGS__) {
  25. #define MIDOUT_E \
  26. } \
  27. MIDOUT_END();
  28. #include "../internal/megdnn_opr_wrapper.inl"
  29. #include "../internal/invoke.h"
  30. #include <array>
  31. #include <chrono>
  32. #include <cstring>
  33. #include <thread>
  34. using namespace mgb;
  35. using namespace opr;
  36. using namespace cg::static_infer;
  37. using intl::WorkspaceLimitGetter;
  // Version tag string -- presumably mixed into the profiling cache key by code
  // outside this excerpt; TODO confirm against its users.
  #define CACHE_KEY_VERSION "v2"

  // X-macro: expands to `cb(Name);` once for every operator name listed below,
  // in this order.
  #define MGB_FOREACH_FASTRUN_OPR(cb)   \
      cb(ConvolutionForward);           \
      cb(ConvBiasForward);              \
      cb(ConvolutionBackwardData);      \
      cb(ConvolutionBackwardFilter);    \
      cb(Convolution3DForward);         \
      cb(Convolution3DBackwardData);    \
      cb(Convolution3DBackwardFilter);  \
      cb(LocalShareForward);            \
      cb(LocalShareBackwardData);       \
      cb(LocalShareBackwardFilter);     \
      cb(DeformableConvForward);        \
      cb(DeformableConvBackwardFilter); \
      cb(DeformableConvBackwardData);   \
      cb(BatchConvBiasForward);
  54. namespace mgb {
  55. namespace opr {
  56. namespace intl {
  57. #define cb(_Opr) \
  58. template <> \
  59. struct AutoAddWorkspaceNeedLimitGetter<megdnn::_Opr> { \
  60. static constexpr bool val = true; \
  61. };
  62. MGB_FOREACH_FASTRUN_OPR(cb)
  63. #undef cb
  64. } // namespace intl
  65. } // namespace opr
  66. } // namespace mgb
  67. namespace {
  68. template <class MegDNNOpr>
  69. struct MegDNNOpr2MGBOpr;
  70. #define cb(_Opr) \
  71. template <> \
  72. struct MegDNNOpr2MGBOpr<megdnn::_Opr> { \
  73. using MGBOpr = opr::_Opr; \
  74. };
  75. MGB_FOREACH_FASTRUN_OPR(cb)
  76. #undef cb
  //! Arity description of an operator type. Declared only; every supported
  //! operator must provide an explicit specialisation.
  template <typename Opr>
  struct OprArityTrait;

  //! Carries the arity constants: NumIn as arity_in, NumOut as arity_out and
  //! their sum as arity.
  template <typename Opr, int NumIn, int NumOut>
  struct OprArityTraitTmpl {
      static constexpr int arity_in = NumIn;
      static constexpr int arity_out = NumOut;
      static constexpr int arity = NumIn + NumOut;
  };
  85. #define INST_ARITY(_Opr, _in, _out) \
  86. template <> \
  87. struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {};
  88. INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1);
  89. INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1);
  90. INST_ARITY(megdnn::Convolution3DForward, 2, 1);
  91. INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1);
  92. INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1);
  93. INST_ARITY(megdnn::LocalShareForward, 2, 1);
  94. INST_ARITY(megdnn::LocalShareBackwardData, 2, 1);
  95. INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1);
  96. INST_ARITY(megdnn::Convolution, 2, 1);
  97. INST_ARITY(megdnn::DeformableConvForward, 4, 1);
  98. INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1);
  99. INST_ARITY(megdnn::BatchConvBiasForward, 4, 1);
  100. INST_ARITY(megdnn::ConvBias, 4, 1);
  101. INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3);
  102. #undef INST_ARITY
  103. template <typename Opr>
  104. constexpr bool opr_supports_preprocess() {
  105. return std::is_same<Opr, megdnn::ConvolutionForward>::value ||
  106. std::is_same<Opr, megdnn::ConvBias>::value;
  107. }
  108. template <typename Opr, bool has_prep>
  109. struct PreprocessFilterImpl {
  110. using T = union {};
  111. };
  112. template <typename Opr>
  113. struct PreprocessFilterImpl<Opr, true> {
  114. using T = typename Opr::PreprocessedFilter;
  115. };
  116. template <typename Opr>
  117. using PreprocessFilter =
  118. typename PreprocessFilterImpl<Opr, opr_supports_preprocess<Opr>()>::T;
  //! Slack (in the unit of RealTimer::get_secs()) that is added to the measured
  //! time of an already profiled algorithm to form the timeout for the
  //! algorithms profiled afterwards.
  constexpr double TIMEOUT_TOLERANCE = 2.0;
  121. template <typename Opr>
  122. struct AlgoChooserFuncId {};
  123. #define DEF_FUNC_ID(func) \
  124. template <> \
  125. struct AlgoChooserFuncId<megdnn::func> { \
  126. __attribute__( \
  127. (unused)) static constexpr sys::TimedFuncInvoker::FuncId ID = \
  128. static_cast<sys::TimedFuncInvoker::FuncId>( \
  129. MGB_HASH_STR("megdnn::" #func)); \
  130. };
  131. MGB_FOREACH_FASTRUN_OPR(DEF_FUNC_ID)
  132. #undef DEF_FUNC_ID
  133. /* =================== TimedProfiler =================== */
  134. /*!
  135. * \brief profile a megdnn opr conv with given param
  136. *
  137. * This class only provides static methods, and the entry point is
  138. * TimedProfiler::profile; it would run profiler in a timed environment by
  139. * sys::TimedFuncInvoker
  140. *
  141. * \tparam Opr megdnn opr impl
  142. */
  143. template <typename Opr>
  144. class TimedProfiler {
  145. static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
  146. static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
  147. static constexpr int arity = OprArityTrait<Opr>::arity;
  148. using ConvTensorShapes = std::array<TensorShape, arity>;
  149. public:
  150. struct Param {
  151. char algo_name[128];
  152. size_t workspace;
  153. DTypeEnum dtypes[arity];
  154. CompNode::Locator comp_node_loc;
  155. ConvTensorShapes shapes;
  156. typename Opr::Param opr_param;
  157. bool allow_weight_preprocess;
  158. //! filled by profile()
  159. mutable double actual_timeout;
  160. };
  161. struct Result {
  162. double time;
  163. };
  164. static Maybe<Result> profile(const Param& param, double& timeout) {
  165. mgb_assert(timeout >= 0);
  166. if (!timeout) {
  167. timeout = timeout_setting;
  168. } else if (timeout_setting) {
  169. timeout = std::min(timeout, timeout_setting);
  170. }
  171. param.actual_timeout =
  172. timeout ? timeout : std::numeric_limits<double>::infinity();
  173. auto res = sys::TimedFuncInvoker::ins().invoke(
  174. AlgoChooserFuncId<Opr>::ID,
  175. TParam::from_pod(const_cast<Param&>(param)), timeout);
  176. if (res.valid())
  177. return res.val().template as_single_pod<Result>();
  178. return None;
  179. }
  180. private:
  181. using TParam = sys::TimedFuncInvoker::Param;
  182. using TResult = sys::TimedFuncInvoker::Result;
  183. static const double timeout_setting;
  184. static double init_timeout_setting();
  185. static TResult prof_impl(const TParam& raw_param);
  186. static void prof_init_device(const TParam& raw_param);
  187. };
  188. template <typename Opr>
  189. const double TimedProfiler<Opr>::timeout_setting =
  190. TimedProfiler<Opr>::init_timeout_setting();
  191. template <typename Opr>
  192. double TimedProfiler<Opr>::init_timeout_setting() {
  193. #if MGB_ENABLE_FASTRUN
  194. sys::TimedFuncInvoker::ins().register_func(
  195. AlgoChooserFuncId<Opr>::ID, &TimedProfiler<Opr>::prof_impl,
  196. &TimedProfiler<Opr>::prof_init_device);
  197. auto to_set = MGB_GETENV("MGB_CONV_PROFILING_TIMEOUT");
  198. if (to_set)
  199. return std::stod(to_set);
  200. #endif
  201. return 0;
  202. }
  //! Concatenates the tuple-like arguments in __VA_ARGS__ with std::tuple_cat
  //! and passes the result to mgb::apply together with a generic lambda, so
  //! that `statement` is evaluated with the elements available as the
  //! parameter pack `args`. The lambda captures by reference.
  #define APPLY(statement, ...)                                  \
      mgb::apply([&](const auto&... args) { return statement; }, \
                 std::tuple_cat(__VA_ARGS__))
  206. template <typename Opr>
  207. typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
  208. const TParam& raw_param) {
  209. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_impl")))
  210. auto&& param = raw_param.as_single_pod<Param>();
  211. CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
  212. auto megdnn_opr = intl::create_megdnn_opr<Opr>(cn);
  213. std::array<TensorLayout, arity> layouts;
  214. auto from_enum = [&](DTypeEnum enumv) -> DType {
  215. switch (enumv) {
  216. #define cb(_dt) \
  217. case DTypeTrait<_dt>::enumv: \
  218. return _dt(1.0f, static_cast<uint8_t>(0))
  219. cb(dtype::Quantized8Asymm);
  220. #undef cb
  221. #define cb(_dt) \
  222. case DTypeTrait<_dt>::enumv: \
  223. return _dt(1.0f)
  224. cb(dtype::QuantizedS8);
  225. cb(dtype::QuantizedS16);
  226. cb(dtype::QuantizedS32);
  227. default:
  228. return DType::from_enum(enumv);
  229. #undef cb
  230. }
  231. };
  232. for (int i = 0; i < arity; ++i) {
  233. layouts[i] = {param.shapes[i], from_enum(param.dtypes[i])};
  234. }
  235. megdnn_opr->param() = param.opr_param;
  236. {
  237. typename Opr::Algorithm* algo = nullptr;
  238. for (auto i : APPLY(megdnn_opr->get_all_algorithms(args...), layouts)) {
  239. if (!strcmp(i->name(), param.algo_name)) {
  240. algo = i;
  241. break;
  242. }
  243. }
  244. mgb_assert(algo, "algorithm %s not found", param.algo_name);
  245. megdnn_opr->execution_policy() = {algo};
  246. }
  247. // Allocate preprocessed weight buffers.
  248. TensorLayoutArray preprocessed_layout;
  249. if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
  250. if (param.allow_weight_preprocess) {
  251. preprocessed_layout = APPLY(
  252. _(megdnn_opr)->deduce_preprocessed_filter_layout(args...),
  253. layouts);
  254. }
  255. });
  256. {
  257. // first allocate a whole chunk to avoid memory fragmentation (here we
  258. // rely on memory allocator to reuse memory)
  259. auto align = cn.get_mem_addr_alignment();
  260. size_t tot_size = align;
  261. for (int i = 0; i < arity; ++i) {
  262. tot_size += layouts[i].span().high_byte + align;
  263. }
  264. for (const auto& layout : preprocessed_layout) {
  265. tot_size += layout.span().high_byte + align;
  266. }
  267. tot_size += param.workspace;
  268. DeviceTensorStorage storage{cn};
  269. storage.ensure_size(tot_size);
  270. }
  271. // allocate input and output memory
  272. std::array<DeviceTensorND, arity_in> inp_val;
  273. std::array<DeviceTensorND, arity_out> out_val;
  274. DeviceTensorND workspace;
  275. for (int i = 0; i < arity_in; ++i) {
  276. inp_val[i]
  277. .comp_node(cn)
  278. .dtype(layouts[i].dtype)
  279. .resize(layouts[i]);
  280. }
  281. for (int i = 0; i < arity_out; ++i) {
  282. out_val[i]
  283. .comp_node(cn)
  284. .dtype(layouts[arity_in + i].dtype)
  285. .resize(layouts[arity_in + i]);
  286. }
  287. megdnn::Workspace mdn_workspace;
  288. // allocate workspace
  289. if (param.workspace) {
  290. workspace.comp_node(cn).dtype(dtype::Byte()).resize({param.workspace});
  291. mdn_workspace.size = param.workspace;
  292. mdn_workspace.raw_ptr = workspace.raw_ptr();
  293. }
  294. // allocate storage for preprocessed filter
  295. SmallVector<DeviceTensorND> flt_val(preprocessed_layout.size());
  296. for (size_t i = 0; i < preprocessed_layout.size(); i++) {
  297. flt_val[i] = {cn, preprocessed_layout[i], preprocessed_layout[i].dtype,
  298. preprocessed_layout[i].format};
  299. }
  300. for (int i = 0; i < arity_in; ++i) {
  301. fill_zero_dev_tensor(inp_val[i]);
  302. }
  303. PreprocessFilter<Opr> prep_flt;
  304. if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
  305. if (!preprocessed_layout.empty()) {
  306. auto&& pf = _(prep_flt);
  307. pf.algorithm_id = nullptr;
  308. pf.tensors.resize(flt_val.size());
  309. for (size_t i = 0; i < flt_val.size(); i++) {
  310. pf.tensors[i] = flt_val[i].as_megdnn();
  311. }
  312. APPLY(_(megdnn_opr)->exec_preprocess(args..., &pf, mdn_workspace),
  313. std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn()),
  314. array_skip<2>(layouts));
  315. }
  316. });
  317. RealTimer timer;
  318. auto ev_start = cn.create_event(CompNode::Event::NEED_TIMER),
  319. ev_end = cn.create_event(CompNode::Event::NEED_TIMER);
  320. ev_start->record();
  321. if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
  322. auto&& opr = _(megdnn_opr);
  323. PreprocessFilter<Opr>* pf =
  324. preprocessed_layout.empty() ? nullptr : &prep_flt;
  325. APPLY(opr->exec(args.as_megdnn()..., pf, mdn_workspace), inp_val,
  326. out_val);
  327. }, /* else */ [&](auto _) {
  328. APPLY(_(megdnn_opr)->exec(args.as_megdnn()..., mdn_workspace), inp_val,
  329. out_val);
  330. });
  331. ev_end->record();
  332. double next_report_time = 0.5;
  333. while (!ev_end->finished()) {
  334. if (timer.get_secs() >= next_report_time) {
  335. mgb_log_warn(
  336. "profiling conv algo %s already took %.3f/%.3f secs"
  337. " (limit can be set by MGB_CONV_PROFILING_TIMEOUT) ",
  338. param.algo_name, timer.get_secs(), param.actual_timeout);
  339. next_report_time = timer.get_secs() + 1;
  340. }
  341. using namespace std::literals;
  342. std::this_thread::sleep_for(1000us);
  343. }
  344. mgb_assert(ev_start->finished());
  345. return TResult::from_pod(Result{ev_start->elapsed_time_until(*ev_end)});
  346. MIDOUT_E
  347. };
  348. template <typename Opr>
  349. void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) {
  350. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_init_device")))
  351. auto&& param = raw_param.as_single_pod<Param>();
  352. CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
  353. // wait for cuda init, so its time does not get accounted in timeout
  354. cn.sync();
  355. MIDOUT_E
  356. }
  357. /* =================== AlgoChooser =================== */
  358. /*!
  359. * \brief choose algorithm according to ExecutionPolicy
  360. *
  361. * This class only provides static methods, and the entry point is
  362. * AlgoChooser::setup_algo. When profiling is needed, it would first try to
  363. * retrive profiling stats from cache, and run TimedProfiler when necessary
  364. *
  365. * \tparam Opr megdnn operator impl
  366. */
  367. template <typename Opr>
  368. class AlgoChooser {
  369. static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
  370. static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
  371. static constexpr int arity = OprArityTrait<Opr>::arity;
  372. using ImplAlgo = typename Opr::Algorithm*;
  373. using MGBOpr = typename MegDNNOpr2MGBOpr<Opr>::MGBOpr;
  374. using ConvTensorLayouts = std::array<TensorLayout, arity>;
  375. class ExeContext {
  376. const ConvTensorLayouts& m_layouts;
  377. Opr* m_megdnn_opr;
  378. const MGBOpr* m_mgb_opr;
  379. bool m_allow_weight_preprocess;
  380. public:
  381. ExeContext(const ConvTensorLayouts& layouts, Opr* megdnn_opr,
  382. const MGBOpr* mgb_opr, bool allow_weight_preprocess)
  383. : m_layouts{layouts},
  384. m_megdnn_opr{megdnn_opr},
  385. m_mgb_opr{mgb_opr},
  386. m_allow_weight_preprocess{allow_weight_preprocess} {
  387. mgb_assert(m_layouts.size() == layouts.size());
  388. static_assert(
  389. std::tuple_size<ConvTensorLayouts>::value == 3 ||
  390. std::tuple_size<ConvTensorLayouts>::value == 5 ||
  391. std::tuple_size<ConvTensorLayouts>::value == 8,
  392. "Convolution AlgoChooser assumes arity = 3 , 5 or 8 (for "
  393. "deformable conv)");
  394. }
  395. Opr* megdnn_opr() const { return m_megdnn_opr; }
  396. const MGBOpr* mgb_opr() const { return m_mgb_opr; }
  397. const TensorLayout& inp_layout(size_t idx) const {
  398. return m_layouts[idx];
  399. }
  400. const ConvTensorLayouts& layouts() const { return m_layouts; }
  401. ImplAlgo choose_by_heuristic(bool reproducible = false) const {
  402. auto opr = m_mgb_opr;
  403. auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
  404. opr->owner_graph(), opr->comp_node(),
  405. opr->execution_policy().workspace_limit);
  406. return APPLY(m_megdnn_opr->get_algorithm_heuristic(
  407. args..., workspace_limit, reproducible),
  408. m_layouts);
  409. }
  410. //! get all candidate algos, and the one choose_by_heuristic() is
  411. //! put first
  412. std::vector<ImplAlgo> get_all_candidates() const {
  413. auto heu = choose_by_heuristic();
  414. auto&& ret =
  415. APPLY(m_megdnn_opr->get_all_algorithms(args...), m_layouts);
  416. bool found = false;
  417. for (size_t i = 0; i < ret.size(); ++i) {
  418. if (ret[i] == heu) {
  419. found = true;
  420. std::swap(ret[i], ret[0]);
  421. break;
  422. }
  423. }
  424. mgb_assert(found,
  425. "algo %s got by heuristic not found in "
  426. "candidate list",
  427. heu->name());
  428. return std::move(ret);
  429. }
  430. //! get candidate algos with workspace limit.
  431. std::vector<ImplAlgo> get_all_candidates_with_workspace_limit() const {
  432. auto&& all_algos = get_all_candidates();
  433. auto opr = m_mgb_opr;
  434. auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
  435. opr->owner_graph(), opr->comp_node(),
  436. opr->execution_policy().workspace_limit);
  437. std::vector<ImplAlgo> ret;
  438. for (auto&& algo : all_algos) {
  439. if (get_workspace_size_bytes(algo) <= workspace_limit) {
  440. ret.push_back(algo);
  441. }
  442. }
  443. return ret;
  444. }
  445. //! get workspace size required for specific algo
  446. size_t get_workspace_size_bytes(ImplAlgo algo) const {
  447. m_megdnn_opr->execution_policy() = {algo};
  448. size_t result;
  449. if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
  450. auto&& opr = _(m_megdnn_opr);
  451. auto prep = this->construct_fake_preprocess_filter();
  452. PreprocessFilter<Opr>* prep_ptr =
  453. prep.valid() ? &prep.val() : nullptr;
  454. result = std::max(
  455. APPLY(opr->get_preprocess_workspace_in_bytes(args...),
  456. m_layouts),
  457. APPLY(opr->get_workspace_in_bytes(args..., prep_ptr),
  458. m_layouts));
  459. }, /* else */ [&](auto _) {
  460. result = APPLY(_(m_megdnn_opr)->get_workspace_in_bytes(args...),
  461. m_layouts);
  462. });
  463. return result;
  464. }
  465. /*!
  466. * \brief profile a single algorithm
  467. *
  468. * This is actually a wrapper that constructs param and call
  469. * TimedProfiler<Opr>::profile for the actual profiling
  470. *
  471. * \param[in,out] timeout set the timeout, and return the actual
  472. * timeout used during profiling
  473. */
  474. Maybe<AlgoChooserProfileCache::ResultEntry> profile_single_algo(
  475. ImplAlgo algo, double& timeout) const;
  476. private:
  477. Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const {
  478. Maybe<PreprocessFilter<Opr>> result = None;
  479. if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
  480. if (!m_allow_weight_preprocess)
  481. return;
  482. auto opr = _(m_megdnn_opr);
  483. auto layout =
  484. APPLY(opr->deduce_preprocessed_filter_layout(args...),
  485. m_layouts);
  486. if (layout.empty())
  487. return;
  488. result = PreprocessFilter<Opr>{};
  489. auto& res = result.val();
  490. res.algorithm_id = nullptr;
  491. res.tensors.resize(layout.size());
  492. for (size_t i = 0; i < layout.size(); i++) {
  493. res.tensors[i] = megdnn::TensorND(nullptr, layout[i]);
  494. }
  495. });
  496. return result;
  497. }
  498. };
  499. //! entrance for getting algorithm according to execution strategy
  500. static ImplAlgo get_algo(ExeContext& ctx) {
  501. using S = mixin::Convolution::ExecutionPolicy::Strategy;
  502. MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE);
  503. switch (ctx.mgb_opr()->execution_policy().strategy) {
  504. case S::HEURISTIC:
  505. return ctx.choose_by_heuristic();
  506. case S::HEURISTIC_REPRODUCIBLE:
  507. return ctx.choose_by_heuristic(true);
  508. case S::PROFILE_HEURISTIC: {
  509. ImplAlgo algo = choose_by_profile(ctx, false, false);
  510. if (algo == nullptr)
  511. algo = ctx.choose_by_heuristic();
  512. return algo;
  513. }
  514. #if MGB_ENABLE_FASTRUN
  515. case S::PROFILE:
  516. return choose_by_profile(ctx, false);
  517. case S::PROFILE_REPRODUCIBLE:
  518. return choose_by_profile(ctx, true);
  519. #endif
  520. default:
  521. mgb_throw(GraphError,
  522. "bad convolution ExecutionPolicy strategy");
  523. }
  524. }
  525. static void get_origin_param_and_layouts(const ExeContext&,
  526. ConvTensorLayouts&,
  527. typename Opr::Param&) {}
  528. //! get all profile result, either by retrieving cache or profiling
  529. static AlgoChooserProfileCache::Result get_profile_result(
  530. ExeContext& ctx, bool enable_update);
  531. static ImplAlgo choose_by_profile(ExeContext& ctx,
  532. bool require_reproducible,
  533. bool enable_update = true);
  534. public:
  535. /*!
  536. * \brief setup algorithm and return workspace size
  537. */
  538. static size_t setup_algo(const ConvTensorLayouts& layouts, Opr* megdnn_opr,
  539. const MGBOpr* mgb_opr,
  540. bool allow_weight_preprocess = false) {
  541. if (WorkspaceLimitGetter::is_prealloc_run(mgb_opr->owner_graph())) {
  542. return 0;
  543. }
  544. ExeContext ctx(layouts, megdnn_opr, mgb_opr, allow_weight_preprocess);
  545. auto algo = get_algo(ctx);
  546. size_t workspace = ctx.get_workspace_size_bytes(algo);
  547. mgb_log_debug(
  548. "%s:tensor layouts (%s %s, %s %s)->(%s %s) :algo=%s "
  549. "workspace=%.2fMiB reproducible=%d",
  550. mgb_opr->dyn_typeinfo()->name,
  551. layouts[0].to_string().c_str(),
  552. layouts[0].dtype.name(),
  553. layouts[1].to_string().c_str(),
  554. layouts[1].dtype.name(),
  555. layouts[layouts.size() - 1].to_string().c_str(),
  556. layouts[layouts.size() - 1].dtype.name(), algo->name(),
  557. workspace / (1024 * 1024.0), algo->is_reproducible());
  558. megdnn_opr->execution_policy() = {algo};
  559. return workspace;
  560. }
  561. };
  562. template <typename Opr>
  563. AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result(
  564. ExeContext& ctx, bool enable_update) {
  565. AlgoChooserProfileCache& cache = ctx.mgb_opr()->profile_cache();
  566. ConvTensorLayouts origin_layouts = ctx.layouts();
  567. typename Opr::Param origin_param = ctx.mgb_opr()->param();
  568. get_origin_param_and_layouts(ctx, origin_layouts, origin_param);
  569. AlgoChooserProfileCache::Key cache_key{origin_layouts.data(),
  570. origin_layouts.size(), &origin_param,
  571. sizeof(origin_param)};
  572. {
  573. auto&& rst = cache.get(cache_key);
  574. if (rst.valid())
  575. return rst.val();
  576. }
  577. AlgoChooserProfileCache::Result prof_rst;
  578. if (!enable_update)
  579. return prof_rst;
  580. std::string str_on_inp_shape = ssprintf(
  581. "on input layouts (%s, %s)", ctx.layouts()[0].to_string().c_str(),
  582. ctx.layouts()[1].to_string().c_str());
  583. double cur_timeout = 0;
  584. RealTimer timer;
  585. for (auto algo : ctx.get_all_candidates_with_workspace_limit()) {
  586. Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst;
  587. std::string msg = ssprintf("profiling %s algorithm %s %s",
  588. ctx.mgb_opr()->dyn_typeinfo()->name,
  589. algo->name(), str_on_inp_shape.c_str());
  590. timer.reset();
  591. MGB_TRY { cur_rst = ctx.profile_single_algo(algo, cur_timeout); }
  592. MGB_CATCH(std::exception & exc, {
  593. mgb_log_warn("caught exception during %s: %s", msg.c_str(),
  594. exc.what());
  595. continue;
  596. })
  597. MGB_CATCH(..., {
  598. mgb_log_warn("caught exception during %s", msg.c_str());
  599. continue;
  600. })
  601. if (!cur_rst.valid()) {
  602. mgb_log_warn("timeout when %s; timeout setting: %.3fsec",
  603. msg.c_str(), cur_timeout);
  604. continue;
  605. }
  606. if (!cur_timeout) {
  607. cur_timeout = timer.get_secs() + TIMEOUT_TOLERANCE;
  608. } else {
  609. cur_timeout =
  610. std::min(cur_timeout, timer.get_secs() + TIMEOUT_TOLERANCE);
  611. }
  612. auto&& rst = cur_rst.val();
  613. mgb_log_debug("%s: workspace: %zu; time: %.3gsec", msg.c_str(),
  614. rst.workspace, rst.time);
  615. prof_rst.push_back(rst);
  616. }
  617. mgb_assert(!prof_rst.empty(), "no usable convolution algorithm %s",
  618. str_on_inp_shape.c_str());
  619. cache.put(cache_key, prof_rst);
  620. return prof_rst;
  621. }
  622. template <>
  623. void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts(
  624. const ExeContext& ctx, ConvTensorLayouts& layouts,
  625. megdnn::ConvBias::Param& param) {
  626. auto format = static_cast<megdnn::param::ConvBias::Format>(
  627. ctx.megdnn_opr()->param().format);
  628. size_t output_block_size = ctx.megdnn_opr()->param().output_block_size;
  629. megdnn::ConvBias::deduce_winograd_origin_layout_and_param(
  630. format, output_block_size, ctx.layouts()[0], ctx.layouts()[1],
  631. layouts[1], param);
  632. }
  633. template <typename Opr>
  634. typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
  635. ExeContext& ctx, bool require_reproducible, bool enable_update) {
  636. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
  637. auto opr = ctx.mgb_opr();
  638. if (opr->owner_graph()->options().no_profiling_on_shape_change) {
  639. auto algo = ctx.megdnn_opr()->execution_policy().algorithm;
  640. if (algo)
  641. return algo;
  642. }
  643. std::unordered_map<std::string, ImplAlgo> algo_map;
  644. for (auto i : ctx.get_all_candidates()) {
  645. auto ins = algo_map.emplace(i->name(), i);
  646. mgb_assert(ins.second, "duplicated algo name: %s", i->name());
  647. }
  648. auto&& prof = get_profile_result(ctx, enable_update);
  649. if (prof.empty())
  650. return nullptr;
  651. for (auto&& i : prof) {
  652. if ((!require_reproducible || i.reproducible)) {
  653. auto iter = algo_map.find(i.algo);
  654. mgb_assert(
  655. iter != algo_map.end(),
  656. "algorithm %s exists in "
  657. "profiling result but not in algo_map; please report this "
  658. "bug; opr: %s{%s}, shapes: %s %s %s",
  659. ctx.mgb_opr()->cname(), ctx.mgb_opr()->dyn_typeinfo()->name,
  660. ctx.layouts()[0].TensorShape::to_string().c_str(),
  661. ctx.layouts()[1].TensorShape::to_string().c_str(),
  662. ctx.layouts()[2].TensorShape::to_string().c_str(),
  663. i.algo.c_str());
  664. return iter->second;
  665. }
  666. }
  667. mgb_log_error(
  668. "Workspace requirement (%zu) could not be satisfied. Abort now to "
  669. "avoid further problems",
  670. WorkspaceLimitGetter::get_workspace_limit(
  671. opr->owner_graph(), opr->comp_node(),
  672. opr->execution_policy().workspace_limit));
  673. mgb_trap();
  674. MIDOUT_E
  675. }
  676. template <typename Opr>
  677. Maybe<AlgoChooserProfileCache::ResultEntry>
  678. AlgoChooser<Opr>::ExeContext::profile_single_algo(ImplAlgo algo,
  679. double& timeout) const {
  680. typename TimedProfiler<Opr>::Param param;
  681. auto name = algo->name();
  682. // force check copy size <= dest len-1 from gcc8 for safe
  683. auto len = sizeof(param.algo_name);
  684. strncpy(param.algo_name, name, len - 1);
  685. param.algo_name[len - 1] = '\0';
  686. mgb_assert(!param.algo_name[sizeof(param.algo_name) - 2],
  687. "algo name too long: %s; len=%zu", name, strlen(name));
  688. param.workspace = get_workspace_size_bytes(algo);
  689. for (int i = 0; i < arity; ++i) {
  690. auto&& src = m_layouts[i];
  691. mgb_assert(src.format.is_default() &&
  692. (src.dtype.category() == DTypeCategory::FLOAT ||
  693. src.dtype.category() == DTypeCategory::INT ||
  694. src.dtype.category() == DTypeCategory::QUANTIZED),
  695. "unsupported layout in profiling: %s",
  696. src.to_string().c_str());
  697. param.dtypes[i] = src.dtype.enumv();
  698. }
  699. param.comp_node_loc = m_mgb_opr->output(0)->comp_node().locator();
  700. mgb_assert(param.shapes.size() == m_layouts.size());
  701. for (size_t i = 0; i < param.shapes.size(); ++i)
  702. param.shapes[i] = m_layouts[i];
  703. param.opr_param = m_megdnn_opr->param();
  704. param.allow_weight_preprocess = m_allow_weight_preprocess;
  705. auto rst = TimedProfiler<Opr>::profile(param, timeout);
  706. // MIOpen conv profiles all available algos when a specfic shape is
  707. // provided for the first time, which probably adds to the result time.
  708. // Therefore, a second profile execution is needed.
  709. if (strncmp(name, "MIOpen", 6) == 0)
  710. rst = TimedProfiler<Opr>::profile(param, timeout);
  711. if (!rst.valid())
  712. return None;
  713. return AlgoChooserProfileCache::ResultEntry{
  714. algo->name(), algo->is_reproducible(), rst.val().time,
  715. param.workspace};
  716. }
  717. } // anonymous namespace
  718. /* ==================== misc impl ==================== */
  // Defaulted destructor, defined out of line in this translation unit.
  mixin::Convolution::~Convolution() = default;
  720. void mixin::Convolution::set_execution_policy(const ExecutionPolicy& policy) {
  721. mgb_throw_if(
  722. m_policy_accessed, InternalError,
  723. "attempt to modify ExecutionPolicy after it has been accessed");
  724. m_policy = policy;
  725. }
  726. template <class MgbOpr, class MegDNNOpr>
  727. void mixin::Convolution::init_output_static_infer_desc_for_bwd_data(
  728. cg::OperatorNodeBase* self) {
  729. using namespace cg::static_infer;
  730. auto&& mgr = self->owner_graph()->static_infer_manager();
  731. DepVal inp_deps;
  732. inp_deps.reserve(4);
  733. for (int i = 0; i < 2; ++i) {
  734. inp_deps.push_back({self->input(i), DepType::SHAPE});
  735. }
  736. // output shape
  737. if (self->input().size() == 3) {
  738. mgr.register_shape_infer(self->output(0),
  739. ShapeInferDesc::make_identity(self->input(2)));
  740. } else {
  741. auto infer_shp = [self](TensorShape& dest, const InpVal& inp) {
  742. TensorLayout ol{self->output(0)->dtype()};
  743. static_cast<MgbOpr*>(self)->megdnn_opr()->deduce_layout(
  744. {inp.val.at(0).shape(), self->input(0)->dtype()},
  745. {inp.val.at(1).shape(), self->input(1)->dtype()}, ol);
  746. dest = ol;
  747. return true;
  748. };
  749. mgr.register_shape_infer(self->output(0),
  750. {SourceType::DEP, inp_deps, infer_shp});
  751. }
  752. // workspace size
  753. auto infer_wk = [self](TensorShape& dest, const InpVal& inp) {
  754. auto&& iv = inp.val;
  755. dest.ndim = 1;
  756. dest.shape[0] = AlgoChooser<MegDNNOpr>::setup_algo(
  757. {TensorLayout{iv[0].shape(), self->input(0)->dtype(),
  758. self->input(0)->format()},
  759. {iv[1].shape(), self->input(1)->dtype(),
  760. self->input(1)->format()},
  761. {iv.at(2).shape(), self->output(0)->dtype(),
  762. self->output(0)->format()}},
  763. static_cast<MgbOpr*>(self)->megdnn_opr(),
  764. static_cast<MgbOpr*>(self));
  765. return true;
  766. };
  767. inp_deps.push_back({self->output(0), DepType::SHAPE});
  768. auto workspace_dep_var =
  769. WorkspaceLimitGetter::register_to_graph(self->owner_graph());
  770. if (workspace_dep_var) {
  771. inp_deps.push_back({workspace_dep_var, DepType::VALUE});
  772. }
  773. mgr.register_shape_infer(self->output(1),
  774. {SourceType::DEP, inp_deps, infer_wk});
  775. }
  // Boilerplate shared by the convolution operator classes in this file.
  // For the class _cls it defines:
  //  * init_profile_cache(): creates m_profile_cache on the operator's comp
  //    node; the cache name is _prof_name followed by CACHE_KEY_VERSION and
  //    the algorithm set name reported by the megdnn operator;
  //  * param_blob(): address and size of the operator param;
  // and it emits MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls).
  // _prof_name must be a string literal because it is concatenated with
  // CACHE_KEY_VERSION by literal juxtaposition.
  #define IMPL_CONV(_cls, _prof_name) \
  void _cls::init_profile_cache() { \
  std::string name(_prof_name CACHE_KEY_VERSION); \
  name.append(megdnn_opr()->get_algorithm_set_name()); \
  m_profile_cache = std::make_unique<AlgoChooserProfileCache>( \
  comp_node(), name.c_str()); \
  } \
  std::pair<const void*, size_t> _cls::param_blob() const { \
  return {&param(), sizeof(Param)}; \
  } \
  MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls)
  787. AlgoChooserProfileCache& mixin::Convolution::profile_cache() const {
  788. if (!m_profile_cache) {
  789. const_cast<Convolution*>(this)->init_profile_cache();
  790. mgb_assert(m_profile_cache);
  791. }
  792. return *m_profile_cache;
  793. }
  794. class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final
  795. : public cg::GraphExecutable::ExecDependency {
  796. std::unique_ptr<PreprocessedFilter> m_pf;
  797. SmallVector<DeviceTensorND> m_filter_storage;
  798. public:
  799. explicit PreprocessedFilterExecDep(
  800. std::unique_ptr<PreprocessedFilter> preprocessed_filter,
  801. SmallVector<DeviceTensorND> filter_storage)
  802. : m_pf(std::move(preprocessed_filter)),
  803. m_filter_storage(std::move(filter_storage)) {}
  804. };
  805. void mixin::WeightPreprocessExecutor::mixin_update_preprocessed_filter(
  806. cg::OperatorNodeBase& opr) {
  807. if (!mixin_allow_weight_preprocess(opr)) return;
  808. auto new_layout = deduce_preprocessed_filter_layout();
  809. if (new_layout.empty()) {
  810. // Weight preprocess was needed before, but no longer needed.
  811. if (m_preprocessed_filter) {
  812. m_preprocessed_filter.reset();
  813. m_filter_storage.clear();
  814. }
  815. return;
  816. }
  817. bool should_update = false;
  818. size_t new_size = new_layout.size();
  819. if (!m_preprocessed_filter ||
  820. m_preprocessed_filter->tensors.size() != new_size) {
  821. should_update = true;
  822. } else {
  823. for (size_t i = 0; i < new_size; i++) {
  824. if (!new_layout[i].eq_layout(
  825. m_preprocessed_filter->tensors[i].layout)) {
  826. should_update = true;
  827. break;
  828. }
  829. }
  830. }
  831. if (!should_update) return;
  832. if (!m_preprocessed_filter) {
  833. m_preprocessed_filter.reset(new PreprocessedFilter{});
  834. }
  835. m_preprocessed_filter->tensors.resize(new_size);
  836. m_filter_storage.resize(new_size);
  837. m_preprocessed_filter->algorithm_id = nullptr;
  838. for (size_t i = 0; i < new_size; i++) {
  839. m_filter_storage[i] = {opr.output(0)->comp_node(), new_layout[i],
  840. new_layout[i].dtype, new_layout[i].format};
  841. m_preprocessed_filter->tensors[i] = m_filter_storage[i].as_megdnn();
  842. }
  843. scn_do_execute_preprocess();
  844. }
  845. void mixin::WeightPreprocessExecutor::record_preprocessed_weight(
  846. cg::GraphExecutable::ExecDependencyArray& deps) {
  847. deps.emplace_back(new PreprocessedFilterExecDep{
  848. std::move(m_preprocessed_filter), std::move(m_filter_storage)});
  849. }
  850. bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
  851. const cg::OperatorNodeBase& opr) const {
  852. if (!opr.owner_graph()->options().graph_opt.weight_preprocess) {
  853. return false;
  854. }
  855. if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
  856. return false;
  857. if (cg::is_const_var_value(opr.input(1)))
  858. return true;
  859. auto* input_opr = opr.input(1)->owner_opr();
  860. if (input_opr->same_type<opr::MultipleDeviceTensorHolder>() ||
  861. input_opr->same_type<opr::MultipleDeviceTensorWithFormatHolder>())
  862. return true;
  863. auto* sdt = input_opr->try_cast_final<opr::SharedDeviceTensor>();
  864. if (sdt && sdt->const_value())
  865. return true;
  866. auto* sdtf = input_opr->try_cast_final<opr::SharedDeviceTensorWithFormat>();
  867. if (sdtf && sdtf->const_value())
  868. return true;
  869. return false;
  870. }
  871. /* ==================== ConvolutionForward ==================== */
  872. IMPL_CONV(ConvolutionForward, "conv_fwd");
  873. ConvolutionForward::ConvolutionForward(VarNode* src, VarNode* filter,
  874. const Param& param,
  875. const ExecutionPolicy& policy,
  876. const OperatorNodeConfig& config)
  877. : Super{src->owner_graph(), config, "conv", {src, filter}} {
  878. init_megdnn_opr(*this, param);
  879. m_policy = policy;
  880. add_input({src, filter});
  881. }
  882. SymbolVar ConvolutionForward::make(SymbolVar src, SymbolVar filter,
  883. const Param& param,
  884. const ExecutionPolicy& policy,
  885. const OperatorNodeConfig& config) {
  886. return src.insert_single_output_opr<ConvolutionForward>(
  887. src.node(), filter.node(), param, policy, config);
  888. }
  889. void ConvolutionForward::init_output_dtype() {
  890. DType output_dtype = config().output_dtype();
  891. megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(),
  892. output_dtype);
  893. output(0)->dtype(output_dtype);
  894. }
  #if MGB_ENABLE_GRAD
  /*!
   * Gradient of ConvolutionForward (float inputs only):
   *  - wrt input 0 (data): ConvolutionBackwardData on (filter, out_grad, data)
   *  - wrt input 1 (filter): ConvolutionBackwardFilter on
   *    (data, out_grad, filter)
   * Both reuse the param and execution policy of the forward operator.
   */
  MGB_IMPL_OPR_GRAD(ConvolutionForward) {
      mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                 "only float data type supported for grad");
      mgb_assert(wrt_idx == 0 || wrt_idx == 1);
      mgb_assert(out_grad.size() == 2);

      if (wrt_idx != 0) {
          // filter
          return ConvolutionBackwardFilter::make(
                         opr.input(0), out_grad[0], opr.input(1), opr.param(),
                         opr.execution_policy())
                  .node();
      }
      // data
      return ConvolutionBackwardData::make(opr.input(1), out_grad[0],
                                           opr.input(0), opr.param(),
                                           opr.execution_policy())
              .node();
  }
  #endif
  916. size_t ConvolutionForward::get_workspace_size_bytes(
  917. const TensorShapeArray& input_shapes,
  918. const TensorShapeArray& output_shapes) const {
  919. mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
  920. return AlgoChooser<megdnn::ConvolutionForward>::setup_algo(
  921. {TensorLayout{input_shapes[0], input(0)->dtype(),
  922. input(0)->format()},
  923. {input_shapes[1], input(1)->dtype(), input(1)->format()},
  924. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  925. megdnn_opr(), this, allow_weight_preprocess());
  926. }
  927. void ConvolutionForward::init_output_format() {
  928. mgb_assert(output().size() == 2);
  929. output(0)->format(input(0)->format());
  930. }
  931. void ConvolutionForward::scn_do_execute() {
  932. update_preprocessed_filter();
  933. megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
  934. input(1)->dev_tensor().as_megdnn(),
  935. output(0)->dev_tensor().as_megdnn(),
  936. preprocessed_filter(),
  937. intl::get_megdnn_workspace_from_var(output().back()));
  938. }
  939. void ConvolutionForward::add_input_layout_constraint() {
  940. mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
  941. }
  942. void ConvolutionForward::init_output_static_infer_desc() {
  943. Super::set_nr_managed_outputs(this->output().size() - 1);
  944. Super::init_output_static_infer_desc();
  945. init_output_static_infer_desc_workspace(
  946. intl::AutoAddWorkspaceNeedLimitGetter<
  947. megdnn::ConvolutionForward>::val);
  948. }
  949. void ConvolutionForward::get_output_var_shape(
  950. const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
  951. TensorLayout input_layout{inp_shape[0], input(0)->dtype(),
  952. input(0)->format()};
  953. TensorLayout filter_layout{inp_shape[1], input(1)->dtype(),
  954. input(1)->format()};
  955. TensorLayout dst_layout{output(0)->dtype(), output(0)->format()};
  956. megdnn_opr()->deduce_layout(input_layout, filter_layout, dst_layout);
  957. out_shape[0] = dst_layout;
  958. }
  959. void ConvolutionForward::record_execute_deps(
  960. cg::GraphExecutable::ExecDependencyArray& deps) {
  961. record_megdnn_opr(deps);
  962. record_preprocessed_weight(deps);
  963. }
  964. SmallVector<TensorLayout>
  965. ConvolutionForward::deduce_preprocessed_filter_layout() {
  966. return megdnn_opr()->deduce_preprocessed_filter_layout(
  967. input(0)->layout(), input(1)->layout(), output(0)->layout());
  968. }
  969. void ConvolutionForward::scn_do_execute_preprocess() {
  970. megdnn_opr()->exec_preprocess(
  971. input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
  972. output(0)->layout(), preprocessed_filter(),
  973. intl::get_megdnn_workspace_from_var(output().back()));
  974. }
  975. /* ==================== ConvolutionBackwardData ==================== */
  976. IMPL_CONV(ConvolutionBackwardData, "conv_bwd_data");
  977. ConvolutionBackwardData::ConvolutionBackwardData(
  978. VarNode* filter, VarNode* diff, VarNode* src_for_shp,
  979. const Param& param, const ExecutionPolicy& policy,
  980. const OperatorNodeConfig& config)
  981. : Super{filter->owner_graph(),
  982. config,
  983. "conv_bwd_data",
  984. {filter, diff}} {
  985. init_megdnn_opr(*this, param);
  986. m_policy = policy;
  987. add_input({filter, diff});
  988. if (src_for_shp) {
  989. add_input({src_for_shp});
  990. }
  991. }
  992. SymbolVar ConvolutionBackwardData::make(SymbolVar filter, SymbolVar diff,
  993. SymbolVar src, const Param& param,
  994. const ExecutionPolicy& policy,
  995. const OperatorNodeConfig& config) {
  996. return filter.insert_single_output_opr<ConvolutionBackwardData>(
  997. filter.node(), diff.node(), src.node(), param, policy, config);
  998. }
  999. SymbolVar ConvolutionBackwardData::make(SymbolVar filter, SymbolVar data,
  1000. const Param& param,
  1001. const ExecutionPolicy& policy,
  1002. const OperatorNodeConfig& config) {
  1003. return make(filter, data, {}, param, policy, config);
  1004. }
  1005. void ConvolutionBackwardData::add_input_layout_constraint() {
  1006. mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
  1007. }
  1008. void ConvolutionBackwardData::init_output_static_infer_desc() {
  1009. init_output_static_infer_desc_for_bwd_data<ConvolutionBackwardData,
  1010. megdnn::ConvolutionBackwardData>(
  1011. this);
  1012. }
  1013. void ConvolutionBackwardData::init_output_dtype() {
  1014. DType output_dtype = config().output_dtype();
  1015. megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(),
  1016. output_dtype);
  1017. output(0)->dtype(output_dtype);
  1018. }
  1019. void ConvolutionBackwardData::init_output_format() {
  1020. mgb_assert(output().size() == 2);
  1021. output(0)->format(input(1)->format());
  1022. }
  1023. cg::OperatorNodeBase::NodeProp* ConvolutionBackwardData::do_make_node_prop()
  1024. const {
  1025. auto prop = Super::Super::do_make_node_prop();
  1026. if (input().size() == 3) {
  1027. using D = NodeProp::DepType;
  1028. prop->reset_dep_type(input(), {D::DEV_VALUE, D::DEV_VALUE, D::SHAPE});
  1029. }
  1030. return prop;
  1031. }
  1032. void ConvolutionBackwardData::scn_do_execute() {
  1033. megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
  1034. input(1)->dev_tensor().as_megdnn(),
  1035. output(0)->dev_tensor().as_megdnn(),
  1036. intl::get_megdnn_workspace_from_var(output(1)));
  1037. }
  #if MGB_ENABLE_GRAD
  /*!
   * Gradient of ConvolutionBackwardData:
   *  - wrt input 0: ConvolutionBackwardFilter on (out_grad, input 1, input 0)
   *  - wrt input 1: Convolution on (out_grad, input 0)
   *  - any other input: no gradient (nullptr)
   * No gradient may flow into the second output.
   */
  MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
      mgb_assert(!out_grad[1]);
      switch (wrt_idx) {
          case 0:
              return ConvolutionBackwardFilter::make(
                             out_grad[0], opr.input(1), opr.input(0),
                             opr.param(), opr.execution_policy())
                      .node();
          case 1:
              return Convolution::make(out_grad[0], opr.input(0), opr.param(),
                                       opr.execution_policy())
                      .node();
          default:
              return nullptr;
      }
  }
  #endif
  1055. /* ==================== ConvolutionBackwardFilter ==================== */
  1056. IMPL_CONV(ConvolutionBackwardFilter, "conv_bwd_filter");
  1057. ConvolutionBackwardFilter::ConvolutionBackwardFilter(
  1058. VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
  1059. const ExecutionPolicy& policy, const OperatorNodeConfig& config)
  1060. : Super({src->owner_graph(),
  1061. config,
  1062. "conv_bwd_filter",
  1063. {src, diff, filter}},
  1064. 2, false) {
  1065. init_megdnn_opr(*this, param);
  1066. m_policy = policy;
  1067. add_input({src, diff, filter});
  1068. }
  1069. SymbolVar ConvolutionBackwardFilter::make(SymbolVar src, SymbolVar diff,
  1070. SymbolVar filter, const Param& param,
  1071. const ExecutionPolicy& policy,
  1072. const OperatorNodeConfig& config) {
  1073. return src.insert_single_output_opr<ConvolutionBackwardFilter>(
  1074. src.node(), diff.node(), filter.node(), param, policy, config);
  1075. }
  1076. size_t ConvolutionBackwardFilter::get_workspace_size_bytes(
  1077. const TensorShapeArray& input_shapes,
  1078. const TensorShapeArray& output_shapes) const {
  1079. mgb_assert(input_shapes.size() == 3 && output_shapes.size() == 1);
  1080. return AlgoChooser<megdnn::ConvolutionBackwardFilter>::setup_algo(
  1081. {TensorLayout{input_shapes[0], input(0)->dtype(),
  1082. input(0)->format()},
  1083. {input_shapes[1], input(1)->dtype(), input(1)->format()},
  1084. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  1085. megdnn_opr(), this);
  1086. }
  #if MGB_ENABLE_GRAD
  /*!
   * Gradient of ConvolutionBackwardFilter:
   *  - wrt input 0: ConvolutionBackwardData on (out_grad, input 1, input 0)
   *  - wrt input 1: Convolution on (input 0, out_grad)
   *  - any other input: no gradient (nullptr)
   * No gradient may flow into the second output.
   */
  MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
      mgb_assert(!out_grad[1]);
      switch (wrt_idx) {
          case 0:
              return ConvolutionBackwardData::make(
                             out_grad[0], opr.input(1), opr.input(0),
                             opr.param(), opr.execution_policy())
                      .node();
          case 1:
              return Convolution::make(opr.input(0), out_grad[0], opr.param(),
                                       opr.execution_policy())
                      .node();
          default:
              return nullptr;
      }
  }
  #endif
  1104. /* ==================== Convolution3DForward ==================== */
  1105. IMPL_CONV(Convolution3DForward, "conv3d_fwd");
  1106. Convolution3DForward::Convolution3DForward(VarNode* src, VarNode* filter,
  1107. const Param& param,
  1108. const ExecutionPolicy& policy,
  1109. const OperatorNodeConfig& config)
  1110. : Super{src->owner_graph(), config, "conv3d", {src, filter}} {
  1111. init_megdnn_opr(*this, param);
  1112. m_policy = policy;
  1113. add_input({src, filter});
  1114. }
  1115. SymbolVar Convolution3DForward::make(SymbolVar src, SymbolVar filter,
  1116. const Param& param,
  1117. const ExecutionPolicy& policy,
  1118. const OperatorNodeConfig& config) {
  1119. return src.insert_single_output_opr<Convolution3DForward>(
  1120. src.node(), filter.node(), param, policy, config);
  1121. }
  1122. void Convolution3DForward::init_output_dtype() {
  1123. switch (param().data_type) {
  1124. case Param::DataType::FLOAT:
  1125. output(0)->dtype(input(0)->dtype());
  1126. break;
  1127. #if !MEGDNN_DISABLE_FLOAT16
  1128. case Param::DataType::FLOAT_IO16xC32:
  1129. mgb_assert(input(0)->dtype() == dtype::Float16(),
  1130. "invalid input dtype %s", input(0)->name().c_str());
  1131. output(0)->dtype(input(0)->dtype());
  1132. break;
  1133. #endif
  1134. default:
  1135. mgb_throw(MegBrainError, "bad data_type enum");
  1136. }
  1137. }
  #if MGB_ENABLE_GRAD
  /*!
   * Gradient of Convolution3DForward (DataType::FLOAT only):
   *  - wrt input 0 (data): Convolution3DBackwardData on
   *    (filter, out_grad, data)
   *  - wrt input 1 (filter): Convolution3DBackwardFilter on
   *    (data, out_grad, filter)
   * Both reuse the param and execution policy of the forward operator.
   */
  MGB_IMPL_OPR_GRAD(Convolution3DForward) {
      mgb_assert(opr.param().data_type ==
                         Convolution3DForward::Param::DataType::FLOAT,
                 "only float data type supported for grad");
      mgb_assert(wrt_idx == 0 || wrt_idx == 1);
      mgb_assert(out_grad.size() == 2);

      if (wrt_idx != 0) {
          // filter
          return Convolution3DBackwardFilter::make(
                         opr.input(0), out_grad[0], opr.input(1), opr.param(),
                         opr.execution_policy())
                  .node();
      }
      // data
      return Convolution3DBackwardData::make(opr.input(1), out_grad[0],
                                             opr.input(0), opr.param(),
                                             opr.execution_policy())
              .node();
  }
  #endif
  1160. size_t Convolution3DForward::get_workspace_size_bytes(
  1161. const TensorShapeArray& input_shapes,
  1162. const TensorShapeArray& output_shapes) const {
  1163. mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
  1164. return AlgoChooser<megdnn::Convolution3DForward>::setup_algo(
  1165. {TensorLayout{input_shapes[0], input(0)->dtype(),
  1166. input(0)->format()},
  1167. {input_shapes[1], input(1)->dtype(), input(1)->format()},
  1168. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  1169. megdnn_opr(), this);
  1170. }
  1171. /* ==================== Convolution3DBackwardData ==================== */
  1172. IMPL_CONV(Convolution3DBackwardData, "conv3d_bwd_data");
  1173. Convolution3DBackwardData::Convolution3DBackwardData(
  1174. VarNode* filter, VarNode* diff, VarNode* src_for_shp,
  1175. const Param& param, const ExecutionPolicy& policy,
  1176. const OperatorNodeConfig& config)
  1177. : Super{filter->owner_graph(),
  1178. config,
  1179. "conv3d_bwd_data",
  1180. {filter, diff}} {
  1181. init_megdnn_opr(*this, param);
  1182. m_policy = policy;
  1183. add_input({filter, diff});
  1184. if (src_for_shp) {
  1185. add_input({src_for_shp});
  1186. }
  1187. }
  1188. SymbolVar Convolution3DBackwardData::make(SymbolVar filter, SymbolVar diff,
  1189. SymbolVar src, const Param& param,
  1190. const ExecutionPolicy& policy,
  1191. const OperatorNodeConfig& config) {
  1192. return filter.insert_single_output_opr<Convolution3DBackwardData>(
  1193. filter.node(), diff.node(), src.node(), param, policy, config);
  1194. }
  1195. SymbolVar Convolution3DBackwardData::make(SymbolVar filter, SymbolVar data,
  1196. const Param& param,
  1197. const ExecutionPolicy& policy,
  1198. const OperatorNodeConfig& config) {
  1199. return make(filter, data, {}, param, policy, config);
  1200. }
  1201. void Convolution3DBackwardData::add_input_layout_constraint() {
  1202. mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
  1203. }
  1204. void Convolution3DBackwardData::init_output_static_infer_desc() {
  1205. init_output_static_infer_desc_for_bwd_data<
  1206. Convolution3DBackwardData, megdnn::Convolution3DBackwardData>(this);
  1207. }
  1208. cg::OperatorNodeBase::NodeProp* Convolution3DBackwardData::do_make_node_prop()
  1209. const {
  1210. auto prop = Super::Super::do_make_node_prop();
  1211. if (input().size() == 3) {
  1212. using D = NodeProp::DepType;
  1213. prop->reset_dep_type(input(), {D::DEV_VALUE, D::DEV_VALUE, D::SHAPE});
  1214. }
  1215. return prop;
  1216. }
  1217. void Convolution3DBackwardData::scn_do_execute() {
  1218. megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
  1219. input(1)->dev_tensor().as_megdnn(),
  1220. output(0)->dev_tensor().as_megdnn(),
  1221. intl::get_megdnn_workspace_from_var(output(1)));
  1222. }
  #if MGB_ENABLE_GRAD
  /*!
   * Gradient of Convolution3DBackwardData:
   *  - wrt input 0: Convolution3DBackwardFilter on
   *    (out_grad, input 1, input 0)
   *  - wrt input 1: Convolution3D on (out_grad, input 0)
   *  - any other input: no gradient (nullptr)
   * No gradient may flow into the second output.
   */
  MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
      mgb_assert(!out_grad[1]);
      switch (wrt_idx) {
          case 0:
              return Convolution3DBackwardFilter::make(
                             out_grad[0], opr.input(1), opr.input(0),
                             opr.param(), opr.execution_policy())
                      .node();
          case 1:
              return Convolution3D::make(out_grad[0], opr.input(0),
                                         opr.param(), opr.execution_policy())
                      .node();
          default:
              return nullptr;
      }
  }
  #endif
  1240. /* ==================== Convolution3DBackwardFilter ==================== */
  1241. IMPL_CONV(Convolution3DBackwardFilter, "conv3d_bwd_filter");
  1242. Convolution3DBackwardFilter::Convolution3DBackwardFilter(
  1243. VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
  1244. const ExecutionPolicy& policy, const OperatorNodeConfig& config)
  1245. : Super({src->owner_graph(),
  1246. config,
  1247. "conv3d_bwd_filter",
  1248. {src, diff, filter}},
  1249. 2, false) {
  1250. init_megdnn_opr(*this, param);
  1251. m_policy = policy;
  1252. add_input({src, diff, filter});
  1253. }
  1254. SymbolVar Convolution3DBackwardFilter::make(SymbolVar src, SymbolVar diff,
  1255. SymbolVar filter,
  1256. const Param& param,
  1257. const ExecutionPolicy& policy,
  1258. const OperatorNodeConfig& config) {
  1259. return src.insert_single_output_opr<Convolution3DBackwardFilter>(
  1260. src.node(), diff.node(), filter.node(), param, policy, config);
  1261. }
  1262. size_t Convolution3DBackwardFilter::get_workspace_size_bytes(
  1263. const TensorShapeArray& input_shapes,
  1264. const TensorShapeArray& output_shapes) const {
  1265. mgb_assert(input_shapes.size() == 3 && output_shapes.size() == 1);
  1266. return AlgoChooser<megdnn::Convolution3DBackwardFilter>::setup_algo(
  1267. {TensorLayout{input_shapes[0], input(0)->dtype(),
  1268. input(0)->format()},
  1269. {input_shapes[1], input(1)->dtype(), input(1)->format()},
  1270. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  1271. megdnn_opr(), this);
  1272. }
  1273. /* ========================== MaskConvolution ========================== */
  1274. MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaskConvolution);
  1275. MaskConvolution::MaskConvolution(VarNode* src, VarNode* filter, VarNode* mask,
  1276. const Param& param,
  1277. const OperatorNodeConfig& config)
  1278. : Super(src->owner_graph(), config, "mask_conv_fwd",
  1279. {src, filter, mask}) {
  1280. init_megdnn_opr(*this, param);
  1281. add_input({src, filter, mask});
  1282. }
  1283. SymbolVar MaskConvolution::make(SymbolVar src, SymbolVar filter, SymbolVar mask,
  1284. const Param& param,
  1285. const OperatorNodeConfig& config) {
  1286. return src.insert_single_output_opr<MaskConvolution>(
  1287. src.node(), filter.node(), mask.node(), param, config);
  1288. }
  1289. void MaskConvolution::init_output_dtype() {
  1290. auto dtype = input(2)->dtype();
  1291. mgb_assert(dtype == dtype::Int32() || dtype == dtype::Int16() ||
  1292. dtype == dtype::Int8(),
  1293. "dtype must be int8, int16 or int32, while get %s",
  1294. dtype.name());
  1295. output(0)->dtype(input(0)->dtype());
  1296. }
  1297. MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaskPropagate);
  1298. MaskPropagate::MaskPropagate(VarNode* src, const Param& param,
  1299. const OperatorNodeConfig& config)
  1300. : Super(src->owner_graph(), config, "mask_propagate", {src}) {
  1301. init_megdnn_opr(*this, param);
  1302. add_input({src});
  1303. }
  1304. void MaskPropagate::init_output_dtype() {
  1305. auto dtype = input(0)->dtype();
  1306. mgb_assert(dtype == dtype::Int32() || dtype == dtype::Int16() ||
  1307. dtype == dtype::Int8());
  1308. output(0)->dtype(dtype);
  1309. }
  1310. SymbolVar MaskPropagate::make(SymbolVar src, const Param& param,
  1311. const OperatorNodeConfig& config) {
  1312. return src.insert_single_output_opr<MaskPropagate>(src.node(), param,
  1313. config);
  1314. }
  1315. /* ==================== ConvBiasForward ==================== */
  1316. IMPL_CONV(ConvBiasForward, "conv_bias_fwd");
  1317. ConvBiasForward::ConvBiasForward(VarNode* src, VarNode* filter,
  1318. const Param& param,
  1319. const ExecutionPolicy& policy,
  1320. const OperatorNodeConfig& config)
  1321. : Super{src->owner_graph(), config, "conv_bias", {src, filter}} {
  1322. init_megdnn_opr(*this, param);
  1323. m_policy = policy;
  1324. add_input({src, filter});
  1325. }
  1326. ConvBiasForward::ConvBiasForward(VarNode* src, VarNode* filter, VarNode* bias,
  1327. const Param& param,
  1328. const ExecutionPolicy& policy,
  1329. const OperatorNodeConfig& config)
  1330. : Super{src->owner_graph(), config, "conv_bias", {src, filter, bias}} {
  1331. m_policy = policy;
  1332. init_megdnn_opr(*this, param);
  1333. add_input({src, filter, bias});
  1334. }
  1335. ConvBiasForward::ConvBiasForward(VarNode* src, VarNode* filter, VarNode* bias,
  1336. VarNode* z, const Param& param,
  1337. const ExecutionPolicy& policy,
  1338. const OperatorNodeConfig& config)
  1339. : Super{src->owner_graph(),
  1340. config,
  1341. "conv_bias",
  1342. {src, filter, bias, z}} {
  1343. m_policy = policy;
  1344. init_megdnn_opr(*this, param);
  1345. add_input({src, filter, bias, z});
  1346. }
  1347. void ConvBiasForward::add_input_layout_constraint() {
  1348. mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
  1349. }
  1350. SymbolVar ConvBiasForward::make(SymbolVar src, SymbolVar filter,
  1351. const Param& param,
  1352. const ExecutionPolicy& policy,
  1353. const OperatorNodeConfig& config) {
  1354. return src.insert_single_output_opr<ConvBiasForward>(
  1355. src.node(), filter.node(), param, policy, config);
  1356. }
  1357. SymbolVar ConvBiasForward::make(SymbolVar src, SymbolVar filter, SymbolVar bias,
  1358. const Param& param,
  1359. const ExecutionPolicy& policy,
  1360. const OperatorNodeConfig& config) {
  1361. return src.insert_single_output_opr<ConvBiasForward>(
  1362. src.node(), filter.node(), bias.node(), param, policy, config);
  1363. }
  1364. SymbolVar ConvBiasForward::make(SymbolVar src, SymbolVar filter, SymbolVar bias,
  1365. SymbolVar z, const Param& param,
  1366. const ExecutionPolicy& policy,
  1367. const OperatorNodeConfig& config) {
  1368. return src.insert_single_output_opr<ConvBiasForward>(
  1369. src.node(), filter.node(), bias.node(), z.node(), param, policy,
  1370. config);
  1371. }
  1372. void ConvBiasForward::init_output_dtype() {
  1373. DType output_dtype = config().output_dtype();
  1374. DType i0, i1, i2, i3;
  1375. mgb_assert(input().size() >= 2 && input().size() <= 4);
  1376. i0 = input(0)->dtype();
  1377. i1 = input(1)->dtype();
  1378. if (input().size() >= 3)
  1379. i2 = input(2)->dtype();
  1380. if (input().size() == 4)
  1381. i3 = input(3)->dtype();
  1382. megdnn_opr()->deduce_dtype(i0, i1, i2, i3, output_dtype);
  1383. output(0)->dtype(output_dtype);
  1384. }
  1385. size_t ConvBiasForward::get_workspace_size_bytes(
  1386. const TensorShapeArray& input_shapes,
  1387. const TensorShapeArray& output_shapes) const {
  1388. auto mo = megdnn_opr();
  1389. TensorLayout i0, i1, i2, i3;
  1390. mgb_assert(input_shapes.size() >= 2 && input_shapes.size() <= 4);
  1391. i0 = {input_shapes[0], input(0)->dtype(), input(0)->format()};
  1392. i1 = {input_shapes[1], input(1)->dtype(), input(1)->format()};
  1393. if (input_shapes.size() >= 3)
  1394. i2 = {input_shapes[2], input(2)->dtype(), input(2)->format()};
  1395. else {
  1396. DType dtype;
  1397. mo->deduce_dtype(input(0)->dtype(), input(1)->dtype(), DType{}, DType{},
  1398. dtype);
  1399. i2 = {{}, dtype};
  1400. }
  1401. if (input_shapes.size() == 4)
  1402. i3 = {input_shapes[3], input(3)->dtype(), input(3)->format()};
  1403. else
  1404. i3 = {{}, output(0)->dtype(), output(0)->format()};
  1405. return AlgoChooser<megdnn::ConvBias>::setup_algo(
  1406. {i0,
  1407. i1,
  1408. i2,
  1409. i3,
  1410. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  1411. mo, this, allow_weight_preprocess());
  1412. }
  1413. void ConvBiasForward::scn_do_execute() {
  1414. update_preprocessed_filter();
  1415. auto&& inp = input();
  1416. auto mo = megdnn_opr();
  1417. if (inp.size() == 2) {
  1418. TensorLayout bias_layout;
  1419. bias_layout.ndim = 0;
  1420. if (output(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
  1421. bias_layout.dtype = dtype::QuantizedS32(
  1422. output(0)->dtype().param<dtype::QuantizedS8>().scale);
  1423. } else {
  1424. bias_layout.dtype = output(0)->dtype();
  1425. }
  1426. TensorLayout z_layout;
  1427. z_layout.ndim = 0;
  1428. z_layout.dtype = output(0)->dtype();
  1429. megdnn::TensorND bias_tensor{nullptr, bias_layout};
  1430. megdnn::TensorND z_tensor{nullptr, z_layout};
  1431. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  1432. inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor,
  1433. output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
  1434. intl::get_megdnn_workspace_from_var(output().back()));
  1435. } else if (inp.size() == 3) {
  1436. TensorLayout z_layout;
  1437. z_layout.ndim = 0;
  1438. z_layout.dtype = output(0)->dtype();
  1439. megdnn::TensorND z_tensor{nullptr, z_layout};
  1440. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  1441. inp[1]->dev_tensor().as_megdnn(),
  1442. inp[2]->dev_tensor().as_megdnn(), z_tensor,
  1443. output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
  1444. intl::get_megdnn_workspace_from_var(output().back()));
  1445. } else {
  1446. mgb_assert(inp.size() == 4);
  1447. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  1448. inp[1]->dev_tensor().as_megdnn(),
  1449. inp[2]->dev_tensor().as_megdnn(),
  1450. inp[3]->dev_tensor().as_megdnn(),
  1451. output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
  1452. intl::get_megdnn_workspace_from_var(output().back()));
  1453. }
  1454. }
  1455. void ConvBiasForward::get_output_var_shape(const TensorShapeArray& inp_shape,
  1456. TensorShapeArray& out_shape) const {
  1457. auto mo = megdnn_opr();
  1458. TensorLayout dst;
  1459. mo->deduce_layout({inp_shape[0], input(0)->dtype(), input(0)->format()},
  1460. {inp_shape[1], input(1)->dtype(), input(0)->format()}, {},
  1461. {}, dst);
  1462. out_shape[0] = dst;
  1463. }
  1464. void ConvBiasForward::init_output_static_infer_desc() {
  1465. Super::set_nr_managed_outputs(this->output().size() - 1);
  1466. Super::init_output_static_infer_desc();
  1467. this->init_output_static_infer_desc_workspace(
  1468. intl::AutoAddWorkspaceNeedLimitGetter<
  1469. megdnn::ConvBiasForward>::val);
  1470. }
  1471. void ConvBiasForward::init_output_format() {
  1472. mgb_assert(output().size() == 2);
  1473. output(0)->format(input(0)->format());
  1474. }
  1475. void ConvBiasForward::check_winograd_param_valid(
  1476. const megdnn::ConvBias::WinogradParam& param,
  1477. const DType& dtype) {
  1478. if (dtype.enumv() == DTypeEnum::Float32) {
  1479. mgb_assert(param.channel_block_size == 1 ||
  1480. param.channel_block_size == 4 ||
  1481. param.channel_block_size == 8,
  1482. "only support 1/4/8 for the channel_block_size of "
  1483. "winograd param, got %u",
  1484. param.channel_block_size);
  1485. } else {
  1486. mgb_assert((MEGDNN_FLOAT16_SELECT(dtype.enumv() == DTypeEnum::Float16,
  1487. false) ||
  1488. dtype.enumv() == DTypeEnum::QuantizedS8 ||
  1489. dtype.enumv() == DTypeEnum::Quantized8Asymm) &&
  1490. (param.channel_block_size == 1 ||
  1491. param.channel_block_size == 4 ||
  1492. param.channel_block_size == 8),
  1493. "only support 1/4/8 for the channel_block_size of "
  1494. "winograd param, got %u",
  1495. param.channel_block_size);
  1496. }
  1497. }
  1498. megdnn::param::MatrixMul::Format ConvBiasForward::get_matmul_format(
  1499. const megdnn::ConvBias::WinogradParam& param) {
  1500. switch (param.channel_block_size) {
  1501. case 1:
  1502. return megdnn::param::MatrixMul::Format::DEFAULT;
  1503. break;
  1504. case 4:
  1505. return megdnn::param::MatrixMul::Format::MK4;
  1506. break;
  1507. case 8:
  1508. return megdnn::param::MatrixMul::Format::MK8;
  1509. break;
  1510. default:
  1511. mgb_throw(InternalError,
  1512. "Only Support 1/4/8 for "
  1513. "channel_block_size, got: %u",
  1514. param.channel_block_size);
  1515. }
  1516. }
  1517. SmallVector<TensorLayout> ConvBiasForward::deduce_preprocessed_filter_layout() {
  1518. TensorLayout i2, i3;
  1519. if (input().size() > 2) {
  1520. i2 = input(2)->layout();
  1521. }
  1522. if (input().size() > 3) {
  1523. i3 = input(3)->layout();
  1524. }
  1525. return megdnn_opr()->deduce_preprocessed_filter_layout(
  1526. input(0)->layout(), input(1)->layout(), i2, i3,
  1527. output(0)->layout());
  1528. }
  1529. void ConvBiasForward::scn_do_execute_preprocess() {
  1530. TensorLayout bias_layout(output(0)->dtype()), z_layout(output(0)->dtype());
  1531. if (input().size() > 2) {
  1532. bias_layout = input(2)->layout();
  1533. }
  1534. if (input().size() > 3) {
  1535. z_layout = input(3)->layout();
  1536. }
  1537. megdnn_opr()->exec_preprocess(
  1538. input(0)->layout(), input(1)->dev_tensor().as_megdnn(), bias_layout,
  1539. z_layout, output(0)->layout(), preprocessed_filter(),
  1540. intl::get_megdnn_workspace_from_var(output().back()));
  1541. }
  1542. /* ===================== LocalShareForward ==================== */
  1543. IMPL_CONV(LocalShareForward, "local_share");
  1544. LocalShareForward::LocalShareForward(VarNode* src, VarNode* filter,
  1545. const Param& param,
  1546. const ExecutionPolicy& policy,
  1547. const OperatorNodeConfig& config)
  1548. : Super{src->owner_graph(), config, "local_share", {src, filter}} {
  1549. init_megdnn_opr(*this, param);
  1550. m_policy = policy;
  1551. add_input({src, filter});
  1552. }
  1553. SymbolVar LocalShareForward::make(SymbolVar src, SymbolVar filter,
  1554. const Param& param,
  1555. const ExecutionPolicy& policy,
  1556. const OperatorNodeConfig& config) {
  1557. return src.insert_single_output_opr<LocalShareForward>(
  1558. src.node(), filter.node(), param, policy, config);
  1559. }
  1560. void LocalShareForward::init_output_dtype() {
  1561. DType output_dtype = config().output_dtype();
  1562. mgb_assert(!output_dtype.valid() || output_dtype == dtype::Float32());
  1563. output_dtype = dtype::Float32();
  1564. output(0)->dtype(output_dtype);
  1565. }
  1566. void LocalShareForward::init_output_format() {
  1567. mgb_assert(output().size() == 2);
  1568. output(0)->format(input(0)->format());
  1569. }
  1570. size_t LocalShareForward::get_workspace_size_bytes(
  1571. const TensorShapeArray& input_shapes,
  1572. const TensorShapeArray& output_shapes) const {
  1573. mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
  1574. return AlgoChooser<megdnn::LocalShareForward>::setup_algo(
  1575. {TensorLayout{input_shapes[0], input(0)->dtype(),
  1576. input(0)->format()},
  1577. {input_shapes[1], input(1)->dtype(), input(1)->format()},
  1578. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  1579. megdnn_opr(), this);
  1580. }
  #if MGB_ENABLE_GRAD
  //! Gradient of LocalShareForward: wrt_idx 0 (data) is computed with
  //! LocalShareBackwardData, wrt_idx 1 (filter) with
  //! LocalShareBackwardFilter. Only float inputs are supported.
  MGB_IMPL_OPR_GRAD(LocalShareForward) {
      mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                 "only float data type supported for grad");
      mgb_assert(wrt_idx == 0 || wrt_idx == 1);
      mgb_assert(out_grad.size() == 2);

      if (wrt_idx != 0) {
          // gradient w.r.t. the filter
          return LocalShareBackwardFilter::make(opr.input(0), out_grad[0],
                                                opr.input(1), opr.param(),
                                                opr.execution_policy())
                  .node();
      }

      // gradient w.r.t. the data
      return LocalShareBackwardData::make(opr.input(1), out_grad[0],
                                          opr.input(0), opr.param(),
                                          opr.execution_policy())
              .node();
  }
  #endif
  1602. /* ===================== LocalShareBackwardData ==================== */
  1603. IMPL_CONV(LocalShareBackwardData, "local_share_bwd_data");
  1604. LocalShareBackwardData::LocalShareBackwardData(VarNode* filter, VarNode* diff,
  1605. VarNode* src_for_shp,
  1606. const Param& param,
  1607. const ExecutionPolicy& policy,
  1608. const OperatorNodeConfig& config)
  1609. : Super{filter->owner_graph(), config, "local_share_bwd_data", {filter, diff}} {
  1610. init_megdnn_opr(*this, param);
  1611. m_policy = policy;
  1612. add_input({filter, diff});
  1613. if (src_for_shp) {
  1614. add_input({src_for_shp});
  1615. }
  1616. }
  1617. SymbolVar LocalShareBackwardData::make(SymbolVar filter, SymbolVar diff,
  1618. SymbolVar src, const Param& param,
  1619. const ExecutionPolicy& policy,
  1620. const OperatorNodeConfig& config) {
  1621. return filter.insert_single_output_opr<LocalShareBackwardData>(
  1622. filter.node(), diff.node(), src.node(), param, policy, config);
  1623. }
  1624. void LocalShareBackwardData::init_output_static_infer_desc() {
  1625. init_output_static_infer_desc_for_bwd_data<LocalShareBackwardData,
  1626. megdnn::LocalShareBackwardData>(
  1627. this);
  1628. }
  1629. void LocalShareBackwardData::init_output_dtype() {
  1630. DType output_dtype = config().output_dtype();
  1631. mgb_assert(!output_dtype.valid() || output_dtype == dtype::Float32());
  1632. output_dtype = dtype::Float32();
  1633. output(0)->dtype(output_dtype);
  1634. }
  1635. void LocalShareBackwardData::add_input_layout_constraint() {
  1636. mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
  1637. }
  1638. cg::OperatorNodeBase::NodeProp* LocalShareBackwardData::do_make_node_prop()
  1639. const {
  1640. auto prop = Super::Super::do_make_node_prop();
  1641. mgb_assert(input().size() == 3);
  1642. using D = NodeProp::DepType;
  1643. prop->reset_dep_type(input(), {D::DEV_VALUE, D::DEV_VALUE, D::SHAPE});
  1644. return prop;
  1645. }
  1646. void LocalShareBackwardData::scn_do_execute() {
  1647. megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
  1648. input(1)->dev_tensor().as_megdnn(),
  1649. output(0)->dev_tensor().as_megdnn(),
  1650. intl::get_megdnn_workspace_from_var(output(1)));
  1651. }
  #if MGB_ENABLE_GRAD
  //! Gradient of LocalShareBackwardData: wrt_idx 0 is computed with
  //! LocalShareBackwardFilter, wrt_idx 1 with LocalShare; any other index
  //! yields nullptr. No gradient may arrive on output 1.
  MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
      mgb_assert(!out_grad[1]);
      switch (wrt_idx) {
          case 0:
              return LocalShareBackwardFilter::make(
                             out_grad[0], opr.input(1), opr.input(0),
                             opr.param(), opr.execution_policy())
                      .node();
          case 1:
              return LocalShare::make(out_grad[0], opr.input(0), opr.param(),
                                      opr.execution_policy())
                      .node();
          default:
              return nullptr;
      }
  }
  #endif
  1669. /* ==================== LocalShareBackwardFilter ==================== */
  1670. IMPL_CONV(LocalShareBackwardFilter, "local_share_bwd_filter");
  1671. LocalShareBackwardFilter::LocalShareBackwardFilter(
  1672. VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
  1673. const ExecutionPolicy& policy, const OperatorNodeConfig& config)
  1674. : Super({src->owner_graph(),
  1675. config,
  1676. "local_share_bwd_filter",
  1677. {src, diff, filter}},
  1678. 2, false) {
  1679. init_megdnn_opr(*this, param);
  1680. m_policy = policy;
  1681. add_input({src, diff, filter});
  1682. }
  1683. SymbolVar LocalShareBackwardFilter::make(
  1684. SymbolVar src, SymbolVar diff, SymbolVar filter,
  1685. const Param &param,
  1686. const ExecutionPolicy &policy,
  1687. const OperatorNodeConfig &config) {
  1688. return src.insert_single_output_opr<LocalShareBackwardFilter>(
  1689. src.node(), diff.node(), filter.node(), param, policy, config);
  1690. }
  1691. size_t LocalShareBackwardFilter::get_workspace_size_bytes(
  1692. const TensorShapeArray &input_shapes,
  1693. const TensorShapeArray &output_shapes) const {
  1694. mgb_assert(input_shapes.size() == 3 && output_shapes.size() == 1);
  1695. return AlgoChooser<megdnn::LocalShareBackwardFilter>::setup_algo(
  1696. {TensorLayout{input_shapes[0], input(0)->dtype(),
  1697. input(0)->format()},
  1698. {input_shapes[1], input(1)->dtype(), input(1)->format()},
  1699. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  1700. megdnn_opr(), this);
  1701. }
  #if MGB_ENABLE_GRAD
  //! Gradient of LocalShareBackwardFilter: wrt_idx 0 is computed with
  //! LocalShareBackwardData, wrt_idx 1 with LocalShare; any other index
  //! yields nullptr. No gradient may arrive on output 1.
  MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
      mgb_assert(!out_grad[1]);
      switch (wrt_idx) {
          case 0:
              return LocalShareBackwardData::make(
                             out_grad[0], opr.input(1), opr.input(0),
                             opr.param(), opr.execution_policy())
                      .node();
          case 1:
              return LocalShare::make(opr.input(0), out_grad[0], opr.param(),
                                      opr.execution_policy())
                      .node();
          default:
              return nullptr;
      }
  }
  #endif
  1717. /* ===================== DeformableConvForward ==================== */
  1718. IMPL_CONV(DeformableConvForward, "deformable_conv");
  1719. DeformableConvForward::DeformableConvForward(VarNode* src, VarNode* filter,
  1720. VarNode* offset, VarNode* mask,
  1721. const Param& param,
  1722. const ExecutionPolicy& policy,
  1723. const OperatorNodeConfig& config)
  1724. : Super{src->owner_graph(),
  1725. config,
  1726. "deformable_conv",
  1727. {src, filter, offset, mask}} {
  1728. mgb_assert(src->dtype() == dtype::Float32() &&
  1729. filter->dtype() == dtype::Float32() &&
  1730. offset->dtype() == dtype::Float32() &&
  1731. mask->dtype() == dtype::Float32(),
  1732. "input should be float32, got %s, %s, %s, %s",
  1733. src->dtype().name(), filter->dtype().name(),
  1734. offset->dtype().name(), mask->dtype().name());
  1735. init_megdnn_opr(*this, param);
  1736. m_policy = policy;
  1737. add_input({src, filter, offset, mask});
  1738. }
  1739. SymbolVar DeformableConvForward::make(SymbolVar src, SymbolVar filter,
  1740. SymbolVar offset, SymbolVar mask,
  1741. const Param& param,
  1742. const ExecutionPolicy& policy,
  1743. const OperatorNodeConfig& config) {
  1744. return src.insert_single_output_opr<DeformableConvForward>(
  1745. src.node(), filter.node(), offset.node(), mask.node(), param,
  1746. policy, config);
  1747. }
  1748. void DeformableConvForward::init_output_dtype() {
  1749. DType output_dtype = config().output_dtype();
  1750. mgb_assert(!output_dtype.valid() || output_dtype == dtype::Float32());
  1751. output_dtype = dtype::Float32();
  1752. output(0)->dtype(output_dtype);
  1753. }
  1754. void DeformableConvForward::init_output_format() {
  1755. mgb_assert(output().size() == 2);
  1756. output(0)->format(input(0)->format());
  1757. }
  1758. size_t DeformableConvForward::get_workspace_size_bytes(
  1759. const TensorShapeArray& input_shapes,
  1760. const TensorShapeArray& output_shapes) const {
  1761. mgb_assert(input_shapes.size() == 4 && output_shapes.size() == 1);
  1762. return AlgoChooser<megdnn::DeformableConvForward>::setup_algo(
  1763. {TensorLayout{input_shapes[0], input(0)->dtype(),
  1764. input(0)->format()},
  1765. {input_shapes[1], input(1)->dtype(), input(1)->format()},
  1766. {input_shapes[2], input(2)->dtype(), input(2)->format()},
  1767. {input_shapes[3], input(3)->dtype(), input(3)->format()},
  1768. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  1769. megdnn_opr(), this);
  1770. }
  #if MGB_ENABLE_GRAD
  //! Gradient of DeformableConvForward w.r.t. input wrt_idx (0: src,
  //! 1: filter, 2: offset, 3: mask).
  //!
  //! The gradients for src, offset and mask are the three outputs of
  //! DeformableConvBackwardData::make_all; the filter gradient comes from
  //! DeformableConvBackwardFilter::make. Only Float32 input is supported and
  //! no gradient may arrive on output 1.
  MGB_IMPL_OPR_GRAD(DeformableConvForward) {
      mgb_assert(opr.input(0)->dtype() == dtype::Float32(),
                 "only float data type supported for grad");
      mgb_assert(wrt_idx < 4);
      // check the number of output grads before out_grad[1] is accessed
      mgb_assert(out_grad.size() == 2);
      mgb_assert(!out_grad[1]);

      // data, offset and mask
      auto grad_arr = DeformableConvBackwardData::make_all(
              opr.input(0), opr.input(1), opr.input(2), opr.input(3),
              out_grad[0], opr.param(), opr.execution_policy(), opr.config());
      // filter
      auto filter_grad = DeformableConvBackwardFilter::make(
              opr.input(0), opr.input(1), opr.input(2), opr.input(3),
              out_grad[0], opr.param(), opr.execution_policy(), opr.config());

      // reorder so that the index matches the forward input index
      SymbolVarArray grads = {grad_arr[0], filter_grad, grad_arr[1],
                              grad_arr[2]};
      return grads[wrt_idx].node();
  }
  #endif
  1790. /* ==================== DeformableConvBackwardData ==================== */
  1791. IMPL_CONV(DeformableConvBackwardData, "deformalbe_conv_backward_data");
  1792. DeformableConvBackwardData::DeformableConvBackwardData(
  1793. VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
  1794. VarNode* diff, const Param& param, const ExecutionPolicy& policy,
  1795. const OperatorNodeConfig& config)
  1796. : Super{filter->owner_graph(),
  1797. config,
  1798. "deformable_conv_backward_data",
  1799. {src, filter, offset, mask, diff}} {
  1800. mgb_assert(src->dtype() == dtype::Float32() and
  1801. filter->dtype() == dtype::Float32() and
  1802. offset->dtype() == dtype::Float32() and
  1803. mask->dtype() == dtype::Float32() and
  1804. diff->dtype() == dtype::Float32(),
  1805. "input should be float32, got %s, %s, %s, %s %s",
  1806. src->dtype().name(), filter->dtype().name(),
  1807. offset->dtype().name(), mask->dtype().name(),
  1808. diff->dtype().name());
  1809. init_megdnn_opr(*this, param);
  1810. m_policy = policy;
  1811. add_input({src, filter, offset, mask, diff});
  1812. }
  1813. SymbolVarArray DeformableConvBackwardData::make_all(
  1814. SymbolVar src, SymbolVar filter, SymbolVar offset, SymbolVar mask,
  1815. SymbolVar diff, const Param& param, const ExecutionPolicy& policy,
  1816. const OperatorNodeConfig& config) {
  1817. auto graph = src.node()->owner_graph();
  1818. auto back_node =
  1819. graph->insert_opr(std::make_unique<DeformableConvBackwardData>(
  1820. src.node(), filter.node(), offset.node(), mask.node(),
  1821. diff.node(), param, policy, config));
  1822. return {back_node->output(0), back_node->output(1), back_node->output(2)};
  1823. }
  1824. SymbolVar DeformableConvBackwardData::make(SymbolVar src, SymbolVar filter,
  1825. SymbolVar offset, SymbolVar mask,
  1826. SymbolVar diff, const Param& param,
  1827. const ExecutionPolicy& policy,
  1828. const OperatorNodeConfig& config) {
  1829. auto&& all =
  1830. make_all(src, filter, offset, mask, diff, param, policy, config);
  1831. return all[0];
  1832. }
  1833. void DeformableConvBackwardData::scn_do_execute() {
  1834. megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), // src
  1835. input(1)->dev_tensor().as_megdnn(), // filter
  1836. input(2)->dev_tensor().as_megdnn(), // offset
  1837. input(3)->dev_tensor().as_megdnn(), // mask
  1838. input(4)->dev_tensor().as_megdnn(), // diff
  1839. output(0)->dev_tensor().as_megdnn(), // src_grad
  1840. output(1)->dev_tensor().as_megdnn(), // offset_grad
  1841. output(2)->dev_tensor().as_megdnn(), // mask_grad
  1842. intl::get_megdnn_workspace_from_var(output(3)));
  1843. }
  1844. void DeformableConvBackwardData::get_output_var_shape(
  1845. const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
  1846. TensorShape im_shp = inp_shape[0];
  1847. TensorShape offset_shp = inp_shape[2];
  1848. TensorShape mask_shp = inp_shape[3];
  1849. mgb_assert(im_shp.ndim == 4, "invalid src shape: %s",
  1850. im_shp.to_string().c_str());
  1851. mgb_assert(offset_shp.ndim == 4, "invalid offset shape: %s",
  1852. offset_shp.to_string().c_str());
  1853. mgb_assert(mask_shp.ndim == 4, "invalid mask shape: %s",
  1854. mask_shp.to_string().c_str());
  1855. mgb_assert(out_shape.size() == 3);
  1856. out_shape[0] = im_shp;
  1857. out_shape[1] = offset_shp;
  1858. out_shape[2] = mask_shp;
  1859. }
  1860. size_t DeformableConvBackwardData::get_workspace_size_bytes(
  1861. const TensorShapeArray& inp_shape,
  1862. const TensorShapeArray& out_shape) const {
  1863. size_t ws = AlgoChooser<megdnn::DeformableConvBackwardData>::setup_algo(
  1864. {TensorLayout{inp_shape[0], input(0)->dtype(), input(0)->format()},
  1865. {inp_shape[1], input(1)->dtype(), input(1)->format()},
  1866. {inp_shape[2], input(2)->dtype(), input(2)->format()},
  1867. {inp_shape[3], input(3)->dtype(), input(3)->format()},
  1868. {inp_shape[4], input(4)->dtype(), input(4)->format()},
  1869. {out_shape[0], output(0)->dtype(), output(0)->format()},
  1870. {out_shape[1], output(1)->dtype(), output(1)->format()},
  1871. {out_shape[2], output(2)->dtype(), output(2)->format()}},
  1872. megdnn_opr(), this);
  1873. return ws;
  1874. }
  1875. void DeformableConvBackwardData::init_output_dtype() {
  1876. DType output_dtype = config().output_dtype();
  1877. mgb_assert(!output_dtype.valid() || output_dtype == dtype::Float32());
  1878. output_dtype = dtype::Float32();
  1879. output(0)->dtype(output_dtype);
  1880. output(1)->dtype(output_dtype);
  1881. output(2)->dtype(output_dtype);
  1882. }
  1883. void DeformableConvBackwardData::init_output_format() {
  1884. mgb_assert(output().size() == 4);
  1885. output(0)->format(input(0)->format());
  1886. output(1)->format(input(2)->format());
  1887. output(2)->format(input(3)->format());
  1888. }
  1889. cg::OperatorNodeBase::NodeProp* DeformableConvBackwardData::do_make_node_prop()
  1890. const {
  1891. auto prop = Super::Super::do_make_node_prop();
  1892. using D = NodeProp::DepType;
  1893. mgb_assert(input().size() == 5);
  1894. prop->reset_dep_type(input(), {D::DEV_VALUE, D::DEV_VALUE, D::DEV_VALUE,
  1895. D::DEV_VALUE, D::DEV_VALUE});
  1896. return prop;
  1897. }
  1898. void DeformableConvBackwardData::init_output_static_infer_desc() {
  1899. Super::set_nr_managed_outputs(this->output().size() - 1);
  1900. Super::init_output_static_infer_desc();
  1901. this->init_output_static_infer_desc_workspace(
  1902. intl::AutoAddWorkspaceNeedLimitGetter<
  1903. megdnn::DeformableConvBackwardData>::val);
  1904. }
  1905. /* ==================== DeformableConvBackwardFilter ==================== */
  1906. IMPL_CONV(DeformableConvBackwardFilter, "deformalbe_conv_backward_filter");
  1907. DeformableConvBackwardFilter::DeformableConvBackwardFilter(
  1908. VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
  1909. VarNode* diff, const Param& param, const ExecutionPolicy& policy,
  1910. const OperatorNodeConfig& config)
  1911. : Super({src->owner_graph(),
  1912. config,
  1913. "deformable_conv_backward_filter",
  1914. {src, filter, offset, mask, diff}},
  1915. 1, false) {
  1916. mgb_assert(src->dtype() == dtype::Float32() and
  1917. filter->dtype() == dtype::Float32() and
  1918. offset->dtype() == dtype::Float32() and
  1919. mask->dtype() == dtype::Float32() and
  1920. diff->dtype() == dtype::Float32(),
  1921. "input should be float32, got %s, %s, %s, %s %s",
  1922. src->dtype().name(), filter->dtype().name(),
  1923. offset->dtype().name(), mask->dtype().name(),
  1924. diff->dtype().name());
  1925. init_megdnn_opr(*this, param);
  1926. m_policy = policy;
  1927. add_input({src, filter, offset, mask, diff});
  1928. }
  1929. SymbolVar DeformableConvBackwardFilter::make(SymbolVar src, SymbolVar filter,
  1930. SymbolVar offset, SymbolVar mask,
  1931. SymbolVar diff, const Param& param,
  1932. const ExecutionPolicy& policy,
  1933. const OperatorNodeConfig& config) {
  1934. return src.insert_single_output_opr<DeformableConvBackwardFilter>(
  1935. src.node(), filter.node(), offset.node(), mask.node(), diff.node(),
  1936. param, policy, config);
  1937. }
  1938. void DeformableConvBackwardFilter::scn_do_execute() {
  1939. megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), // src
  1940. input(2)->dev_tensor().as_megdnn(), // offset
  1941. input(3)->dev_tensor().as_megdnn(), // mask
  1942. input(4)->dev_tensor().as_megdnn(), // diff
  1943. output(0)->dev_tensor().as_megdnn(), // filter_diff
  1944. intl::get_megdnn_workspace_from_var(output(1)));
  1945. }
  1946. size_t DeformableConvBackwardFilter::get_workspace_size_bytes(
  1947. const TensorShapeArray& input_shapes,
  1948. const TensorShapeArray& output_shapes) const {
  1949. mgb_assert(input_shapes.size() == 5 && output_shapes.size() == 1);
  1950. return AlgoChooser<megdnn::DeformableConvBackwardFilter>::setup_algo(
  1951. {TensorLayout{input_shapes[0], input(0)->dtype(),
  1952. input(0)->format()},
  1953. {input_shapes[2], input(2)->dtype(), input(2)->format()},
  1954. {input_shapes[3], input(3)->dtype(), input(3)->format()},
  1955. {input_shapes[4], input(4)->dtype(), input(4)->format()},
  1956. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  1957. megdnn_opr(), this);
  1958. }
  1959. /* ==================== BatchConvBiasForward ==================== */
  1960. IMPL_CONV(BatchConvBiasForward, "batch_conv_bias_fwd");
  1961. BatchConvBiasForward::BatchConvBiasForward(VarNode* src, VarNode* filter,
  1962. const Param& param,
  1963. const ExecutionPolicy& policy,
  1964. const OperatorNodeConfig& config)
  1965. : Super{src->owner_graph(), config, "batch_conv_bias", {src, filter}} {
  1966. init_megdnn_opr(*this, param);
  1967. m_policy = policy;
  1968. add_input({src, filter});
  1969. }
  1970. BatchConvBiasForward::BatchConvBiasForward(VarNode* src, VarNode* filter,
  1971. VarNode* bias, const Param& param,
  1972. const ExecutionPolicy& policy,
  1973. const OperatorNodeConfig& config)
  1974. : Super{src->owner_graph(),
  1975. config,
  1976. "batch_conv_bias",
  1977. {src, filter, bias}} {
  1978. m_policy = policy;
  1979. init_megdnn_opr(*this, param);
  1980. add_input({src, filter, bias});
  1981. }
  1982. BatchConvBiasForward::BatchConvBiasForward(VarNode* src, VarNode* filter,
  1983. VarNode* bias, VarNode* z,
  1984. const Param& param,
  1985. const ExecutionPolicy& policy,
  1986. const OperatorNodeConfig& config)
  1987. : Super{src->owner_graph(),
  1988. config,
  1989. "batch_conv_bias",
  1990. {src, filter, bias, z}} {
  1991. m_policy = policy;
  1992. init_megdnn_opr(*this, param);
  1993. add_input({src, filter, bias, z});
  1994. }
  1995. void BatchConvBiasForward::add_input_layout_constraint() {
  1996. mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
  1997. }
  1998. SymbolVar BatchConvBiasForward::make(SymbolVar src, SymbolVar filter,
  1999. const Param& param,
  2000. const ExecutionPolicy& policy,
  2001. const OperatorNodeConfig& config) {
  2002. return src.insert_single_output_opr<BatchConvBiasForward>(
  2003. src.node(), filter.node(), param, policy, config);
  2004. }
  2005. SymbolVar BatchConvBiasForward::make(SymbolVar src, SymbolVar filter,
  2006. SymbolVar bias, const Param& param,
  2007. const ExecutionPolicy& policy,
  2008. const OperatorNodeConfig& config) {
  2009. return src.insert_single_output_opr<BatchConvBiasForward>(
  2010. src.node(), filter.node(), bias.node(), param, policy, config);
  2011. }
  2012. SymbolVar BatchConvBiasForward::make(SymbolVar src, SymbolVar filter,
  2013. SymbolVar bias, SymbolVar z,
  2014. const Param& param,
  2015. const ExecutionPolicy& policy,
  2016. const OperatorNodeConfig& config) {
  2017. return src.insert_single_output_opr<BatchConvBiasForward>(
  2018. src.node(), filter.node(), bias.node(), z.node(), param, policy,
  2019. config);
  2020. }
  2021. void BatchConvBiasForward::init_output_dtype() {
  2022. DType output_dtype = config().output_dtype();
  2023. DType i0, i1, i2, i3;
  2024. mgb_assert(input().size() >= 2 && input().size() <= 4);
  2025. i0 = input(0)->dtype();
  2026. i1 = input(1)->dtype();
  2027. if (input().size() >= 3)
  2028. i2 = input(2)->dtype();
  2029. if (input().size() == 4)
  2030. i3 = input(3)->dtype();
  2031. megdnn_opr()->deduce_dtype(i0, i1, i2, i3, output_dtype);
  2032. output(0)->dtype(output_dtype);
  2033. }
  2034. size_t BatchConvBiasForward::get_workspace_size_bytes(
  2035. const TensorShapeArray& input_shapes,
  2036. const TensorShapeArray& output_shapes) const {
  2037. auto mo = megdnn_opr();
  2038. TensorLayout i0, i1, i2, i3;
  2039. mgb_assert(input_shapes.size() >= 2 && input_shapes.size() <= 4);
  2040. i0 = {input_shapes[0], input(0)->dtype(), input(0)->format()};
  2041. i1 = {input_shapes[1], input(1)->dtype(), input(1)->format()};
  2042. if (input_shapes.size() >= 3)
  2043. i2 = {input_shapes[2], input(2)->dtype(), input(2)->format()};
  2044. else {
  2045. DType dtype;
  2046. mo->deduce_dtype(input(0)->dtype(), input(1)->dtype(), DType{}, DType{},
  2047. dtype);
  2048. i2 = {{}, dtype};
  2049. }
  2050. if (input_shapes.size() == 4)
  2051. i3 = {input_shapes[3], input(3)->dtype(), input(3)->format()};
  2052. else
  2053. i3 = {{}, output(0)->dtype(), output(0)->format()};
  2054. return AlgoChooser<megdnn::BatchConvBias>::setup_algo(
  2055. {i0,
  2056. i1,
  2057. i2,
  2058. i3,
  2059. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  2060. mo, this);
  2061. }
  2062. void BatchConvBiasForward::scn_do_execute() {
  2063. auto&& inp = input();
  2064. auto mo = megdnn_opr();
  2065. if (inp.size() == 2) {
  2066. TensorLayout bias_layout;
  2067. bias_layout.ndim = 0;
  2068. if (output(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
  2069. bias_layout.dtype = dtype::QuantizedS32(
  2070. output(0)->dtype().param<dtype::QuantizedS8>().scale);
  2071. } else {
  2072. bias_layout.dtype = output(0)->dtype();
  2073. }
  2074. TensorLayout z_layout;
  2075. z_layout.ndim = 0;
  2076. z_layout.dtype = output(0)->dtype();
  2077. megdnn::TensorND bias_tensor{nullptr, bias_layout};
  2078. megdnn::TensorND z_tensor{nullptr, z_layout};
  2079. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  2080. inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor,
  2081. output(0)->dev_tensor().as_megdnn(),
  2082. intl::get_megdnn_workspace_from_var(output().back()));
  2083. } else if (inp.size() == 3) {
  2084. TensorLayout z_layout;
  2085. z_layout.ndim = 0;
  2086. z_layout.dtype = output(0)->dtype();
  2087. megdnn::TensorND z_tensor{nullptr, z_layout};
  2088. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  2089. inp[1]->dev_tensor().as_megdnn(),
  2090. inp[2]->dev_tensor().as_megdnn(), z_tensor,
  2091. output(0)->dev_tensor().as_megdnn(),
  2092. intl::get_megdnn_workspace_from_var(output().back()));
  2093. } else {
  2094. mgb_assert(inp.size() == 4);
  2095. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  2096. inp[1]->dev_tensor().as_megdnn(),
  2097. inp[2]->dev_tensor().as_megdnn(),
  2098. inp[3]->dev_tensor().as_megdnn(),
  2099. output(0)->dev_tensor().as_megdnn(),
  2100. intl::get_megdnn_workspace_from_var(output().back()));
  2101. }
  2102. }
  2103. void BatchConvBiasForward::get_output_var_shape(
  2104. const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
  2105. auto mo = megdnn_opr();
  2106. TensorLayout dst;
  2107. mo->deduce_layout({inp_shape[0], input(0)->dtype(), input(0)->format()},
  2108. {inp_shape[1], input(1)->dtype(), input(0)->format()}, {},
  2109. {}, dst);
  2110. out_shape[0] = dst;
  2111. }
  2112. void BatchConvBiasForward::init_output_static_infer_desc() {
  2113. Super::set_nr_managed_outputs(this->output().size() - 1);
  2114. Super::init_output_static_infer_desc();
  2115. this->init_output_static_infer_desc_workspace(
  2116. intl::AutoAddWorkspaceNeedLimitGetter<
  2117. megdnn::BatchConvBiasForward>::val);
  2118. }
  2119. void BatchConvBiasForward::init_output_format() {
  2120. mgb_assert(output().size() == 2);
  2121. output(0)->format(input(0)->format());
  2122. }
  2123. #undef IMPL_CONV
  2124. #undef MGB_FOREACH_FASTRUN_OPR
  2125. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。