You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

algo_chooser.cpp 52 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202
  1. #include <limits>
  2. #include <unordered_set>
  3. #include "megbrain/exception.h"
  4. #include "megbrain/rdnn/algo_chooser.h"
  5. #include "megbrain/utils/invoke.h"
  6. //! TODO: here has to be know some megdnn::opr when there is produced midout.h
  7. //! fix it if there is another graceful way.
  8. #include "megdnn/opr_param_defs.h"
  9. #include "megdnn/oprs.h"
  10. #include "megdnn/oprs/base.h"
  11. #include "midout.h"
MIDOUT_DECL(megbrain_opr_algo_chooser)
//! Open a midout trace section (and a scope) keyed by the given ids;
//! MIDOUT_E closes both. Used to make each instantiation strippable.
#define MIDOUT_B(...) MIDOUT_BEGIN(megbrain_opr_algo_chooser, __VA_ARGS__) {
#define MIDOUT_E \
    }            \
    MIDOUT_END();
using namespace megdnn;
using namespace mgb;
//! Invoke \p statement with the elements of the trailing tuples expanded as
//! the parameter pack `args...` (via mgb::apply over std::tuple_cat).
#define APPLY(statement, ...)                               \
    mgb::apply(                                             \
            [&](const auto&... args) { return statement; }, \
            std::tuple_cat(__VA_ARGS__))
// timeout delta to be added with fastest known algorithm for new algos
constexpr double TIMEOUT_TOLERANCE = 2;
  25. namespace {
//! Trait mapping each fastrun-capable MegDNN operator class to its literal
//! type-name string (e.g. "ConvolutionForward"); used for log messages and
//! to build profile-cache category names.
template <class MegDNNOpr>
struct MegDNNOpr2Typename;
// Specialize the trait for every operator in DNN_FOREACH_FASTRUN_OPR.
#define cb(_Opr)                                      \
    template <>                                       \
    struct MegDNNOpr2Typename<megdnn::_Opr> {         \
        static const char* name;                      \
    };                                                \
    const char* MegDNNOpr2Typename<megdnn::_Opr>::name = #_Opr;
DNN_FOREACH_FASTRUN_OPR(cb)
#undef cb
  36. template <typename Opr>
  37. std::string profile_name(Opr* opr) {
  38. std::string ret = std::string(::MegDNNOpr2Typename<Opr>::name) + CACHE_KEY_VERSION;
  39. ret.append(opr->get_algorithm_set_name());
  40. return ret;
  41. }
  42. template <typename Opr>
  43. std::string format_fixlayouts(
  44. const typename rdnn::AlgoChooser<Opr>::FixedTensorLayouts& layouts,
  45. size_t arity_in, size_t arity_out, const std::string& delimiter = " -> ") {
  46. std::string ret;
  47. if (arity_in) {
  48. ret.append("(");
  49. for (size_t i = 0; i < arity_in; ++i) {
  50. if (i) {
  51. ret.append(", ");
  52. }
  53. ret.append(layouts[i].to_string() + " ");
  54. }
  55. ret.append(")");
  56. }
  57. if (arity_in && arity_out) {
  58. ret.append(delimiter);
  59. }
  60. if (arity_out) {
  61. ret.append("(");
  62. for (size_t i = 0; i < arity_out; ++i) {
  63. if (i) {
  64. ret.append(", ");
  65. }
  66. ret.append(layouts[i + arity_in].to_string() + " ");
  67. }
  68. ret.append(")");
  69. }
  70. return ret;
  71. }
  72. /**
  73. * \brief Check if the sub opr list has circular dependence.
  74. */
  75. class CircularDepsChecker {
  76. struct SearchItemStorage {
  77. std::string data_hold;
  78. size_t hash = 0;
  79. SearchItemStorage(const Algorithm::SearchItem& item) {
  80. Algorithm::serialize_write_pod(item.opr_type, data_hold);
  81. for (auto&& layout : item.layouts) {
  82. data_hold += layout.serialize();
  83. }
  84. data_hold += item.param;
  85. }
  86. SearchItemStorage& init_hash() {
  87. hash = XXHash64CT::hash(data_hold.data(), data_hold.size(), 20201225);
  88. return *this;
  89. }
  90. bool operator==(const SearchItemStorage& rhs) const {
  91. return data_hold == rhs.data_hold;
  92. }
  93. struct Hash {
  94. size_t operator()(const SearchItemStorage& s) const { return s.hash; }
  95. };
  96. };
  97. std::unordered_set<SearchItemStorage, SearchItemStorage::Hash> m_set;
  98. public:
  99. void put(const megdnn::Algorithm::SearchItem& key) {
  100. SearchItemStorage key_storage(key);
  101. key_storage.init_hash();
  102. mgb_assert(
  103. m_set.find(key_storage) == m_set.end(),
  104. "Circular dependency during flatten search space");
  105. auto ret = m_set.insert(std::move(key_storage));
  106. mgb_assert(ret.second);
  107. }
  108. void remove(const megdnn::Algorithm::SearchItem& key) {
  109. SearchItemStorage key_storage(key);
  110. key_storage.init_hash();
  111. auto&& iter = m_set.find(key_storage);
  112. mgb_assert(iter != m_set.end());
  113. m_set.erase(iter);
  114. }
  115. };
///////////////// OprTypeTrait /////////////////////////////
//! Bidirectional mapping between megdnn::Algorithm::OprType enum values and
//! the corresponding megdnn operator classes; used to dispatch on the
//! runtime opr_type of a SearchItem.
template <megdnn::Algorithm::OprType>
struct OprFromOprTypeTrait;
template <typename Opr>
struct OprTypeFromOprTrait;
// Instantiate both directions of the mapping for one operator.
#define cb(_opr_type, _opr)                                               \
    template <>                                                           \
    struct OprFromOprTypeTrait<megdnn::Algorithm::OprType::_opr_type> {   \
        using Opr = megdnn::_opr;                                         \
    };                                                                    \
    template <>                                                           \
    struct OprTypeFromOprTrait<megdnn::_opr> {                            \
        constexpr static megdnn::Algorithm::OprType opr_type =            \
                megdnn::Algorithm::OprType::_opr_type;                    \
    }
cb(MATRIX_MUL_FORWARD, MatrixMulForward);
cb(BATCHED_MATRIX_MUL_FORWARD, BatchedMatrixMulForward);
cb(CONVOLUTION_FORWARD, ConvolutionForward);
cb(CONVOLUTION_BACKWARD_DATA, ConvolutionBackwardData);
cb(CONVOLUTION_BACKWARD_FILTER, ConvolutionBackwardFilter);
cb(CONVOLUTION3D_FORWARD, Convolution3DForward);
cb(CONVOLUTION3D_BACKWARD_DATA, Convolution3DBackwardData);
cb(CONVOLUTION3D_BACKWARD_FILTER, Convolution3DBackwardFilter);
cb(LOCAL_SHARE_FORWARD, LocalShareForward);
cb(LOCAL_SHARE_BACKWARD_DATA, LocalShareBackwardData);
cb(LOCAL_SHARE_BACKWARD_FILTER, LocalShareBackwardFilter);
cb(DEFORMABLE_CONV_FORWARD, DeformableConvForward);
cb(DEFORMABLE_CONV_BACKWARD_DATA, DeformableConvBackwardData);
cb(DEFORMABLE_CONV_BACKWARD_FILTER, DeformableConvBackwardFilter);
cb(BATCH_CONV_FORWARD, BatchConvBiasForward);
cb(CONVBIAS_FORWARD, ConvBiasForward);
cb(POOLING_FORWARD, PoolingForward);
cb(POOLING_BACKWARD, PoolingBackward);
#undef cb
// clang-format off
//! Expand cb(OPR_TYPE, stmt) once for every supported OprType; combined with
//! _OPR_TYPE_CASE below this builds a switch over all operator types.
#define FOREACH_OPR_TYPE_WITH_STMT(cb, stmt)  \
    cb(MATRIX_MUL_FORWARD, stmt)              \
    cb(BATCHED_MATRIX_MUL_FORWARD, stmt)      \
    cb(CONVOLUTION_FORWARD, stmt)             \
    cb(CONVOLUTION_BACKWARD_DATA, stmt)       \
    cb(CONVOLUTION_BACKWARD_FILTER, stmt)     \
    cb(CONVOLUTION3D_FORWARD, stmt)           \
    cb(CONVOLUTION3D_BACKWARD_DATA, stmt)     \
    cb(CONVOLUTION3D_BACKWARD_FILTER, stmt)   \
    cb(LOCAL_SHARE_FORWARD, stmt)             \
    cb(LOCAL_SHARE_BACKWARD_DATA, stmt)       \
    cb(LOCAL_SHARE_BACKWARD_FILTER, stmt)     \
    cb(DEFORMABLE_CONV_FORWARD, stmt)         \
    cb(DEFORMABLE_CONV_BACKWARD_DATA, stmt)   \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER, stmt) \
    cb(BATCH_CONV_FORWARD, stmt)              \
    cb(CONVBIAS_FORWARD, stmt)                \
    cb(POOLING_FORWARD, stmt)                 \
    cb(POOLING_BACKWARD, stmt)
// clang-format on
//! One case of the dispatch switch: binds _Opr to the operator class matching
//! _opr_type, then executes _stmt.
#define _OPR_TYPE_CASE(_opr_type, _stmt)                                            \
    case Algorithm::OprType::_opr_type: {                                           \
        using _Opr = typename OprFromOprTypeTrait<Algorithm::OprType::_opr_type>::Opr; \
        _stmt;                                                                      \
        break;                                                                      \
    }
//! Run _stmt once for every element of _search_items, with _item bound to the
//! current SearchItem and _Opr bound to its concrete operator class; throws
//! MegBrainError on an unrecognized opr_type.
#define FOREACH_OPR_TYPE_DISPATCH(_search_items, _stmt)                       \
    for (size_t _item_idx = 0; _item_idx < _search_items.size(); _item_idx++) { \
        auto&& _item = _search_items[_item_idx];                              \
        switch (_item.opr_type) {                                             \
            FOREACH_OPR_TYPE_WITH_STMT(_OPR_TYPE_CASE, _stmt)                 \
            default:                                                          \
                mgb_throw(MegBrainError, "unknown opr_type");                 \
        }                                                                     \
    }
  186. template <typename Opr>
  187. TensorLayoutArray to_layout_array(
  188. const typename rdnn::AlgoChooser<Opr>::FixedTensorLayouts& layouts) {
  189. TensorLayoutArray ret;
  190. for (auto&& layout : layouts) {
  191. ret.push_back(layout);
  192. }
  193. return ret;
  194. }
  195. template <typename Opr>
  196. typename rdnn::AlgoChooser<Opr>::FixedTensorLayouts to_fixed_layouts(
  197. const TensorLayoutArray& layouts) {
  198. typename rdnn::AlgoChooser<Opr>::FixedTensorLayouts ret;
  199. mgb_assert(ret.size() == layouts.size());
  200. size_t idx = 0;
  201. for (auto&& layout : layouts) {
  202. ret[idx++] = layout;
  203. }
  204. return ret;
  205. }
/**
 * flatten search space in postorder traversal
 * The subopr search construct a search tree
 *
 *           A
 *         /   \
 *       B1 B2  C
 *      /  \
 *   D1 D2 D3   E
 *
 * We use postorder traverse the search tree.
 * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
 *
 * \param helper  describes the current operator node
 * \param checker guards against circular sub-opr dependencies
 * \return search items of the whole subtree, current node last
 */
template <typename Opr>
std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
        const typename rdnn::AlgoChooser<Opr>::AlgoChooserHelper& helper,
        CircularDepsChecker& checker) {
    auto&& search_item = megdnn::Algorithm::SearchItem{
            OprTypeFromOprTrait<Opr>::opr_type, helper.param(),
            to_layout_array<Opr>(helper.fastrun_layouts())};
    // mark the current node as "on path"; asserts on a cycle
    checker.put(search_item);
    std::vector<megdnn::Algorithm::SearchItem> ret;
    for (auto algo_info : helper.get_all_candidates()) {
        megdnn::Algorithm* algo = helper.get_algorithm_from_desc(algo_info.desc);
        mgb_assert(algo, "Unknown algo description");
        // sub operators required by this candidate algorithm
        std::vector<megdnn::Algorithm::SearchItem>&& sub_items = algo->get_subopr_list(
                to_layout_array<Opr>(helper.fastrun_layouts()), helper.megdnn_opr());
        // recurse into every sub opr first (postorder)
        FOREACH_OPR_TYPE_DISPATCH(sub_items, {
            auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(helper.comp_node());
            megdnn_opr->param() =
                    Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param);
            typename rdnn::AlgoChooser<_Opr>::AlgoChooserHelper sub_helper(
                    to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
                    _item.param, helper.comp_node(), helper.execution_policy(),
                    helper.allow_weight_preprocess(), helper.desc());
            auto space = flatten_search_space<_Opr>(sub_helper, checker);
            ret.insert(ret.end(), space.begin(), space.end());
        });
    }
    // the current item comes after all of its sub items (postorder)
    ret.push_back(search_item);
    checker.remove(search_item);
    return ret;
}
  248. //! serialize a algo's desc to string. format is
  249. //! handle_type|algo_type|size_of_param|size_of_name|string_of_param|string_of_name
  250. static void serialize_write_pod(const Algorithm::Info::Desc& val, std::string& result) {
  251. megdnn::Algorithm::serialize_write_pod(val.handle_type, result);
  252. megdnn::Algorithm::serialize_write_pod(val.type, result);
  253. uint32_t param_size = val.param.size();
  254. uint32_t name_size = val.name.size();
  255. megdnn::Algorithm::serialize_write_pod<uint32_t>(param_size, result);
  256. megdnn::Algorithm::serialize_write_pod<uint32_t>(name_size, result);
  257. megdnn::Algorithm::serialize_write_pod(val.param, result);
  258. megdnn::Algorithm::serialize_write_pod(val.name, result);
  259. }
  260. static Algorithm::Info::Desc deserialize_read_pod(
  261. const std::string& data, size_t offset = 0) {
  262. Algorithm::Info::Desc ret;
  263. #define cb(_val, _type) \
  264. _val = megdnn::Algorithm::deserialize_read_pod<_type>(data.data(), offset); \
  265. offset += sizeof(_val)
  266. cb(ret.handle_type, megdnn::Handle::HandleType);
  267. cb(ret.type, uint32_t);
  268. uint32_t param_size = 0;
  269. uint32_t name_size = 0;
  270. cb(param_size, uint32_t);
  271. cb(name_size, uint32_t);
  272. if (param_size > 0) {
  273. ret.param = std::string(data.data() + offset, param_size);
  274. offset += param_size;
  275. }
  276. if (name_size > 0) {
  277. ret.name = std::string(data.data() + offset, name_size);
  278. offset += name_size;
  279. }
  280. return ret;
  281. }
  282. } // namespace
namespace megdnn {
namespace param {
//! Enable bitwise operators (|, &, ...) on ExecutionPolicy::Strategy so
//! strategies can be combined and tested as flags.
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy)
}  // namespace param
}  // namespace megdnn
  288. namespace mgb {
  289. namespace rdnn {
//! Rewrites the batch dimension of an operator's tensor layouts to a given
//! batch size so profiling results can be shared across batch sizes.
//! This primary template is a no-op; the specializations below handle the
//! concrete operator types whose batch semantics are known.
template <class Opr>
class LayoutsModifier {
    using FixedTensorLayouts = typename AlgoChooser<Opr>::FixedTensorLayouts;

public:
    //! default: leave the layouts untouched
    static void on(FixedTensorLayouts&, const typename Opr::Param&, size_t) {}

private:
    //! index of batch in tensor, 3 for CHWN4 e.g.
    static size_t index_of_batch(const typename Opr::Param&) { return 0; }
    //! indices contain batch in inputs and outputs, src(0) dst(2) for conv e.g.
    static std::vector<size_t> sm_indices_contain_batch;
};
template <class Opr>
std::vector<size_t> LayoutsModifier<Opr>::sm_indices_contain_batch = {};
//! Define a LayoutsModifier specialization for operators whose batch
//! dimension is always at index 0 (no CHWN4-style formats); \p idxs lists
//! the positions of the tensors (inputs and outputs) that carry a batch.
#define DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(opr, idxs)                          \
    template <>                                                                 \
    class LayoutsModifier<opr> {                                                \
    public:                                                                     \
        using FixedTensorLayouts = typename AlgoChooser<opr>::FixedTensorLayouts; \
        static void on(                                                         \
                FixedTensorLayouts& layouts, const opr::Param& param,           \
                size_t new_batch_size) {                                        \
            size_t batch_index = index_of_batch(param);                         \
            for (size_t index : sm_indices_contain_batch) {                     \
                layouts.at(index)[batch_index] = new_batch_size;                \
            }                                                                   \
        }                                                                       \
                                                                                \
    private:                                                                    \
        static size_t index_of_batch(const opr::Param&) { return 0; }           \
        static std::vector<size_t> sm_indices_contain_batch;                    \
    };                                                                          \
    std::vector<size_t> LayoutsModifier<opr>::sm_indices_contain_batch = idxs;
DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::Convolution3DForward, (std::initializer_list<size_t>{0, 2}))
DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::Convolution3DBackwardData, (std::initializer_list<size_t>{1, 2}))
DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::Convolution3DBackwardFilter, (std::initializer_list<size_t>{0, 1}))
DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::BatchedMatrixMul, (std::initializer_list<size_t>{0, 1, 2}))
#undef DEFAULT_OPR_WITHOUT_INPUT_BROADCAST
//! Same as DEFAULT_OPR_WITHOUT_INPUT_BROADCAST but for convolution-like
//! operators that support the CHWN4 format, where the batch dimension moves
//! to index 3 instead of 0.
#define CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(opr, idxs)                        \
    template <>                                                                 \
    class LayoutsModifier<opr> {                                                \
    public:                                                                     \
        using FixedTensorLayouts = typename AlgoChooser<opr>::FixedTensorLayouts; \
        static void on(                                                         \
                FixedTensorLayouts& layouts, const opr::Param& param,           \
                size_t new_batch_size) {                                        \
            size_t batch_index = index_of_batch(param);                         \
            for (size_t index : sm_indices_contain_batch) {                     \
                layouts.at(index)[batch_index] = new_batch_size;                \
            }                                                                   \
        }                                                                       \
                                                                                \
    private:                                                                    \
        static size_t index_of_batch(const opr::Param& param) {                 \
            if (param.format == opr::Param::Format::CHWN4) {                    \
                return 3;                                                       \
            }                                                                   \
            return 0;                                                           \
        }                                                                       \
        static std::vector<size_t> sm_indices_contain_batch;                    \
    };                                                                          \
    std::vector<size_t> LayoutsModifier<opr>::sm_indices_contain_batch = idxs;
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::ConvolutionForward, (std::initializer_list<size_t>{0, 2}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::ConvolutionBackwardData, (std::initializer_list<size_t>{1, 2}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::ConvolutionBackwardFilter, (std::initializer_list<size_t>{0, 1}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::LocalShareForward, (std::initializer_list<size_t>{0, 2}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::LocalShareBackwardData, (std::initializer_list<size_t>{1, 2}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::LocalShareBackwardFilter, (std::initializer_list<size_t>{0, 1}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::DeformableConvForward, (std::initializer_list<size_t>{0, 2, 3, 4}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::DeformableConvBackwardData,
        (std::initializer_list<size_t>{0, 2, 3, 4, 5, 6, 7}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::DeformableConvBackwardFilter,
        (std::initializer_list<size_t>{0, 1, 2, 3}))
CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(
        megdnn::BatchConvBiasForward, (std::initializer_list<size_t>{0, 1, 2, 3, 4}))
#undef CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST
  378. template <>
  379. class LayoutsModifier<megdnn::ConvBiasForward> {
  380. public:
  381. using FixedTensorLayouts =
  382. typename AlgoChooser<megdnn::ConvBiasForward>::FixedTensorLayouts;
  383. static void on(
  384. FixedTensorLayouts& layouts, const megdnn::ConvBiasForward::Param& param,
  385. size_t new_batch_size) {
  386. size_t batch_index = index_of_batch(param);
  387. for (size_t index : sm_indices_contain_batch) {
  388. layouts.at(index)[batch_index] = new_batch_size;
  389. }
  390. for (size_t index : sm_indices_contain_batch_broadcast) {
  391. if (!check_bias_share_in_channel(layouts.at(index), param.format)) {
  392. layouts.at(index)[batch_index] = new_batch_size;
  393. }
  394. }
  395. }
  396. private:
  397. static std::vector<size_t> sm_indices_contain_batch;
  398. static std::vector<size_t> sm_indices_contain_batch_broadcast;
  399. static size_t index_of_batch(const megdnn::ConvBiasForward::Param& param) {
  400. if (param.format == megdnn::ConvBiasForward::Param::Format::CHWN4) {
  401. return 3;
  402. }
  403. return 0;
  404. }
  405. };
  406. std::vector<size_t> LayoutsModifier<megdnn::ConvBiasForward>::sm_indices_contain_batch =
  407. {0, 3, 4};
  408. std::vector<size_t>
  409. LayoutsModifier<megdnn::ConvBiasForward>::sm_indices_contain_batch_broadcast = {
  410. 2};
  411. template <>
  412. class LayoutsModifier<megdnn::MatrixMul> {
  413. public:
  414. using FixedTensorLayouts =
  415. typename AlgoChooser<megdnn::MatrixMul>::FixedTensorLayouts;
  416. static void on(
  417. FixedTensorLayouts& layouts, const megdnn::MatrixMul::Param& param,
  418. size_t new_batch_size) {
  419. //! Because we do not know whether the batch size is in the dimension m
  420. //! or the dimension n, we just ignore both m and n here.
  421. // FIXME Find a way to make mgb obtain batch size information from R or
  422. // automatically
  423. layouts.at(2)[0] = new_batch_size;
  424. layouts.at(2)[1] = new_batch_size;
  425. if (param.transposeA) {
  426. layouts.at(0)[1] = new_batch_size;
  427. } else {
  428. layouts.at(0)[0] = new_batch_size;
  429. }
  430. if (param.transposeB) {
  431. layouts.at(1)[0] = new_batch_size;
  432. } else {
  433. layouts.at(1)[1] = new_batch_size;
  434. }
  435. }
  436. };
///////////////////////////// AlgoChooserHelper //////////////////////////
//! Captures everything needed to choose an algorithm for one megdnn operator
//! instance: the layouts used for profiling (m_fastrun_layouts) and the
//! layouts used as the cache key (m_incache_layouts), which may differ when
//! batch-size sharing or shape-insensitive caching is enabled.
template <typename Opr>
AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper(
        const FixedTensorLayouts& layouts, Opr* megdnn_opr,
        const std::string& param_str, const CompNode& cn,
        const megdnn::param::ExecutionPolicy& execution_policy,
        bool allow_weight_preprocess, const AlgoChooserDesc& desc)
        : m_fastrun_layouts{layouts},
          m_incache_layouts{layouts},
          m_dnn_opr{megdnn_opr},
          m_param{param_str},
          m_cn{cn},
          m_execution_policy{execution_policy},
          m_allow_weight_preprocess{allow_weight_preprocess},
          m_desc{desc} {
    // With a shared fastrun batch size, profile at that batch size but key
    // the cache with batch 0 so results are shared across batch sizes.
    auto fastrun_batch_size = desc.shared_batch_size;
    if (fastrun_batch_size) {
        LayoutsModifier<Opr>::on(m_incache_layouts, m_dnn_opr->param(), 0);
        LayoutsModifier<Opr>::on(
                m_fastrun_layouts, m_dnn_opr->param(), fastrun_batch_size);
    }
    // With no_profiling_on_shape_change, zero every shape and stride in the
    // cache key so a single cache entry matches any shape.
    if (m_desc.no_profiling_on_shape_change) {
        for (size_t i = 0; i < m_incache_layouts.size(); i++) {
            for (size_t j = 0; j < m_incache_layouts.at(i).ndim; j++) {
                m_incache_layouts.at(i)[j] = 0;
                m_incache_layouts.at(i).stride[j] = 0;
            }
        }
    }
    mgb_assert(m_fastrun_layouts.size() == layouts.size());
    static_assert(
            std::tuple_size<FixedTensorLayouts>::value == 2 ||
                    std::tuple_size<FixedTensorLayouts>::value == 3 ||
                    std::tuple_size<FixedTensorLayouts>::value == 4 ||
                    std::tuple_size<FixedTensorLayouts>::value == 5 ||
                    std::tuple_size<FixedTensorLayouts>::value == 8,
            "Pooling assumes arity = 2 or 4,Convolution AlgoChooser assumes "
            "arity = 3 , 5 or 8 (for deformable conv)");
}
//! Choose an execution policy without profiling: ask the megdnn operator for
//! its heuristic algorithm under the workspace limit and the attributes
//! implied by \p selected_strategy, then recursively do the same for every
//! sub opr the chosen algorithm requires.
template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelper::
        choose_by_heuristic(const ExecutionStrategy& selected_strategy) const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_heuristic")))
    ImplExecutionPolicy policy;
    auto workspace_limit =
            m_desc.get_workspace_limit(m_cn, m_execution_policy.workspace_limit);
    // attr.first = required attributes, attr.second = forbidden attributes
    auto attr = extract_algo_attribute(selected_strategy);
    policy.algo = APPLY(m_dnn_opr->get_algorithm_info_heuristic(
                                args..., workspace_limit, attr.first, attr.second),
                        m_fastrun_layouts)
                          .desc;
    Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
    mgb_assert(algo, "Unknown algo description");
    // recursively pick algorithms for the chosen algorithm's sub oprs
    std::vector<Algorithm::SearchItem>&& sub_items =
            algo->get_subopr_list(to_layout_array<Opr>(m_fastrun_layouts), m_dnn_opr);
    FOREACH_OPR_TYPE_DISPATCH(sub_items, {
        auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn);
        megdnn_opr->param() =
                Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param);
        typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper(
                to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), _item.param,
                m_cn, m_execution_policy, m_allow_weight_preprocess, m_desc);
        policy.sub_policy.push_back(sub_helper.choose_by_heuristic(selected_strategy));
    });
    return policy;
    MIDOUT_E
}
//! Choose an execution policy using profiled results: first try the fastrun
//! cache; if it misses and \p enable_update is set, profile the flattened
//! search space (this opr and all sub oprs) and retry the cache; fall back to
//! the heuristic only if everything else fails.
template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelper::
        choose_by_profile(
                const ExecutionStrategy& selected_strategy, bool enable_update) const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile")))
    // no_profiling_on_shape_change is usually false, no interface to change it easily
    if (m_desc.no_profiling_on_shape_change) {
        // reuse whatever policy the operator already carries, if valid
        auto policy = m_dnn_opr->execution_policy();
        if (policy.algo.valid()) {
            return policy;
        }
        if (is_matmul<Opr>()) {
            mgb_log_warn(
                    "choose algo by heuristic, which may cause performance "
                    "regression.");
            return choose_by_heuristic(selected_strategy);
        }
    }
    // first attempt: retrieve from cache without logging on a miss
    typename AlgoChooser<Opr>::ImplExecutionPolicy tmp_policy;
    bool retrive_from_cache = true;
    bool allow_log = false;
    construct_execution_policy(
            selected_strategy, tmp_policy, retrive_from_cache, allow_log);
    if (tmp_policy.algo.valid()) {
        // return the policy when construction from cache succeeded
        return tmp_policy;
    }
    // if update enabled, do profiling and update cache
    // enable_update = false only when using the HEURISTIC_PROFILE strategy
    if (enable_update) {
        CircularDepsChecker circular_deps_checker;
        auto&& search_items = flatten_search_space<Opr>(*this, circular_deps_checker);
        FOREACH_OPR_TYPE_DISPATCH(search_items, {
            auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn);
            // skip different sub opr, for example:
            // skip matmul algo when profiling convolution
            if (m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type())
                continue;
            megdnn_opr->param() =
                    Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param);
            typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper(
                    to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
                    _item.param, m_cn, m_execution_policy, m_allow_weight_preprocess,
                    m_desc);
            sub_helper.profile(selected_strategy);
        });
    }
    // try to retrieve the algorithm from the fastrun cache again; after
    // profiling it is expected to hit (retrive_from_cache = true, allow_log = true)
    typename AlgoChooser<Opr>::ImplExecutionPolicy policy;
    construct_execution_policy(selected_strategy, policy);
    if (policy.algo.valid())
        return policy;
    // last resort: heuristic choice
    return choose_by_heuristic(selected_strategy);
    MIDOUT_E
}
  560. template <typename Opr>
  561. std::pair<
  562. typename AlgoChooser<Opr>::ImplAlgoDesc, Maybe<AlgoChooserProfileCache::Result>>
  563. AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
  564. const ExecutionStrategy& selected_strategy) const {
  565. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache")))
  566. AlgoChooserProfileCache cache(m_cn, profile_name(m_dnn_opr).c_str());
  567. typename Opr::Param origin_param = m_dnn_opr->param();
  568. AlgoChooserProfileCache::Key cache_key{
  569. m_incache_layouts.data(), m_incache_layouts.size(), &origin_param,
  570. sizeof(origin_param)};
  571. auto&& rst = cache.get(cache_key);
  572. // failed to find a cache entry, return
  573. if (!rst.valid())
  574. return {{}, rst};
  575. // found a cache entry(it's a vector of Result), but it's empty
  576. auto&& prof = rst.val();
  577. if (prof.empty())
  578. return {{}, rst};
  579. // found non-empty cache result, filter it by workspace limit and attribute
  580. size_t workspace_limit =
  581. m_desc.get_workspace_limit(m_cn, m_execution_policy.workspace_limit);
  582. auto target_attr = extract_algo_attribute(selected_strategy);
  583. bool skip_by_negative = false;
  584. bool skip_by_workspace = false;
  585. for (auto&& i : prof) {
  586. auto attr_of_algo = static_cast<megdnn::Algorithm::Attribute>(i.attribute);
  587. bool contain_attr_all_positive =
  588. (target_attr.first == (attr_of_algo & target_attr.first));
  589. bool contain_attr_any_negative =
  590. static_cast<bool>(attr_of_algo & target_attr.second);
  591. if (contain_attr_all_positive) {
  592. if (!contain_attr_any_negative) {
  593. if (i.workspace <= workspace_limit) {
  594. // found a well-suited algothrim with good workspace limit and
  595. // correct attribute
  596. Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
  597. return {algo_desc, rst};
  598. }
  599. skip_by_workspace = true;
  600. } else {
  601. skip_by_negative = true;
  602. }
  603. }
  604. }
  605. // failed to find an algorithm that satisfies the actual workspace limit
  606. if (skip_by_workspace)
  607. return {};
  608. // failed to find an algorithm that satisfies the actual attribute
  609. std::string layouts_str = AlgoChooser::format_fixlayouts(m_fastrun_layouts);
  610. if (skip_by_negative) {
  611. mgb_log_error(
  612. "opr: %s, layouts: %s, No usable algo. There are available "
  613. "algos match "
  614. "positive strategy(%s), but filtered by negative stategy(%s).",
  615. ::MegDNNOpr2Typename<Opr>::name, layouts_str.c_str(),
  616. Algorithm::attribute_str(target_attr.first).c_str(),
  617. Algorithm::attribute_str(target_attr.second).c_str());
  618. } else {
  619. mgb_log_error(
  620. "opr: %s, layouts: %s, No usable algo. algos read from cache "
  621. "could not "
  622. "satisfy positive strategy(%s)",
  623. ::MegDNNOpr2Typename<Opr>::name, layouts_str.c_str(),
  624. Algorithm::attribute_str(target_attr.first).c_str());
  625. }
  626. mgb_trap();
  627. MIDOUT_E
  628. }
//! Fill `policy` with a concrete algorithm and, recursively, with algorithms
//! for every sub-operator the chosen algorithm spawns.
//! \param selected_strategy strategy used to pick and filter algorithms.
//! \param policy in/out; left (or reset to) invalid when no usable algorithm
//!        can be found, which the caller must check.
//! \param retrive_from_cache when true, look the algorithm up in the profile
//!        cache; when false, fall back to the heuristic choice.
//! \param allow_log emit a debug log when the cache lookup comes up empty.
template <typename Opr>
void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
        const ExecutionStrategy& selected_strategy,
        typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, bool retrive_from_cache,
        bool allow_log) const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy")))
    // policy.algo is always invalid when called from choose_by_profile
    // policy.algo will be valid when called from profile
    if (!policy.algo.valid()) {
        if (retrive_from_cache) {
            policy.algo = get_profile_result_from_cache(selected_strategy).first;
            // nothing is found even with profiling
            if (!policy.algo.valid()) {
                if (allow_log) {
                    auto target_attr = extract_algo_attribute(selected_strategy);
                    std::string layouts_str =
                            AlgoChooser::format_fixlayouts(m_fastrun_layouts);
                    std::string msg = ssprintf(
                            "(opr : %s, layouts %s, with attribute(%s) and "
                            "without attribute(%s)",
                            ::MegDNNOpr2Typename<Opr>::name, layouts_str.c_str(),
                            Algorithm::attribute_str(target_attr.first).c_str(),
                            Algorithm::attribute_str(target_attr.second).c_str());
                    mgb_log_debug(
                            "No algo get from cache for %s. This may caused by "
                            "mismatch with model and cache file or imcomplete "
                            "cache file. ex. profiling with version1, but "
                            "inferencing on version2 or profiling modelA but "
                            "inferencing modelB",
                            msg.c_str());
                }
                // leave policy.algo invalid; the caller detects the failure
                return;
            }
        } else {
            // retrive_from_cache = false happens when using algo choose hook in
            // megbrain graph return heuristic algorithm in this case
            auto workspace_limit = m_desc.get_workspace_limit(
                    m_cn, m_execution_policy.workspace_limit);
            auto attr = extract_algo_attribute(selected_strategy);
            policy.algo =
                    APPLY(m_dnn_opr->get_algorithm_info_heuristic(
                                  args..., workspace_limit, attr.first, attr.second),
                          m_fastrun_layouts)
                            .desc;
            mgb_assert(
                    policy.algo.valid(),
                    "No algo found from heuristic with strategy %u and "
                    "workspace limit %zu",
                    static_cast<uint32_t>(selected_strategy), workspace_limit);
        }
    }
    // construct current algorithm
    Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
    mgb_assert(algo, "Unknown algo description");
    std::vector<Algorithm::SearchItem>&& sub_items =
            algo->get_subopr_list(to_layout_array<Opr>(m_fastrun_layouts), m_dnn_opr);
    // construct sub oprs' algorithm; recurse with a fresh helper per sub-opr
    FOREACH_OPR_TYPE_DISPATCH(sub_items, {
        auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn);
        megdnn_opr->param() =
                Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param);
        typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper(
                to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), _item.param,
                m_cn, m_execution_policy, m_allow_weight_preprocess, m_desc);
        policy.sub_policy.push_back({});
        sub_helper.construct_execution_policy(
                selected_strategy, policy.sub_policy.back(), retrive_from_cache,
                allow_log);
        if (!policy.sub_policy.back().algo.valid()) {
            // means sub_helper.construct_execution_policy fails. clean up
            // policy.algo and return
            policy = {};
            return;
        }
    });
    MIDOUT_E
}
//! Query the workspace size (in bytes) needed to run `policy`.
//! \param policy execution policy whose algorithm determines the workspace.
//! \param layouts optional layout override; when its first layout has
//!        ndim == 0, the helper's fast-run layouts are used instead.
//! NOTE: mutates m_dnn_opr->execution_policy() as a side effect, since the
//! megdnn workspace query depends on the currently selected algorithm.
template <typename Opr>
size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes(
        const ImplExecutionPolicy& policy, const FixedTensorLayouts& layouts) const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_workspace_size_bytes")))
    m_dnn_opr->execution_policy() = policy;
    size_t result;
    const FixedTensorLayouts* layouts_ptr = &m_fastrun_layouts;
    if (layouts.at(0).ndim) {
        layouts_ptr = &layouts;
    }
    if_constexpr<opr_supports_preprocess<Opr>()>(
            [&](auto _) {
                auto&& opr = _(m_dnn_opr);
                // when the opr supports weight preprocessing, report the max
                // of the preprocess pass and the exec pass, since the two
                // passes do not run concurrently
                auto prep = this->construct_fake_preprocess_filter(*layouts_ptr);
                PreprocessFilter<Opr>* prep_ptr = prep.valid() ? &prep.val() : nullptr;
                result = std::max(
                        APPLY(opr->get_preprocess_workspace_in_bytes(args...),
                              *layouts_ptr),
                        APPLY(opr->get_workspace_in_bytes(args..., prep_ptr),
                              *layouts_ptr));
            },
            /* else */
            [&](auto _) {
                result = APPLY(
                        _(m_dnn_opr)->get_workspace_in_bytes(args...), *layouts_ptr);
            });
    return result;
    MIDOUT_E
}
//! Return every algorithm applicable to the fast-run layouts, with the
//! heuristic choice swapped to the front, and assert that the heuristic
//! algorithm actually appears in the candidate list.
template <typename Opr>
std::vector<typename AlgoChooser<Opr>::ImplAlgo> AlgoChooser<
        Opr>::AlgoChooserHelper::get_all_candidates() const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates")))
    auto heu = choose_by_heuristic(m_execution_policy.strategy);
    auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_fastrun_layouts);
    bool found = false;
    // make heuristic algorithm always the first in all candidate algorithms
    // so profiling step will always run heuristic algorithm first
    for (size_t i = 0; i < ret.size(); ++i) {
        if (ret[i].desc == heu.algo) {
            found = true;
            std::swap(ret[i], ret[0]);
            break;
        }
    }
    // make sure heuristic algorithm is valid
    Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(heu.algo);
    mgb_assert(palgo, "Unknown algo description");
    mgb_assert(
            found,
            "algo %s got by heuristic not found in "
            "candidate list",
            palgo->name());
    // std::move is required here: `ret` is a named rvalue reference, so a
    // plain return would copy (NRVO does not apply to references)
    return std::move(ret);
    MIDOUT_E
}
//! Time a single execution policy in the TimedProfiler subprocess.
//! \param policy fully constructed policy (including sub-policies) to time.
//! \param timeout in/out timeout in seconds forwarded to TimedProfiler.
//! \return a profile-cache entry {algo desc, attribute, time, workspace}, or
//!         None when the run timed out or the subprocess reported the
//!         DBL_MAX sentinel (memory limit exceeded).
template <typename Opr>
Maybe<AlgoChooserProfileCache::ResultEntry> AlgoChooser<Opr>::AlgoChooserHelper::
        profile_single_algo(const ImplExecutionPolicy& policy, double& timeout) const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile_single_algo")))
    // fill TimedProfiler<Opr>::param and run actual timed profiler
    typename TimedProfiler<Opr>::Param param;
    // force check copy size <= dest len-1 from gcc8 for safe
    param.execution_policy =
            TimedProfiler<Opr>::Param::ExecutionPolicyBlob::serialize(policy);
    param.workspace = get_workspace_size_bytes(policy);
    // only default-format layouts with float/int/quantized dtype, or low-bit
    // layouts with a lowbit-aligned format, can be profiled
    for (int i = 0; i < arity; ++i) {
        auto&& src = m_fastrun_layouts[i];
        bool cond_normal = src.format.is_default() &&
                           (src.dtype.category() == DTypeCategory::FLOAT ||
                            src.dtype.category() == DTypeCategory::INT ||
                            src.dtype.category() == DTypeCategory::QUANTIZED);
        bool cond_low_bit = src.dtype.is_low_bit() && src.format.is_lowbit_aligned() &&
                            (src.dtype.category() == DTypeCategory::QUANTIZED ||
                             src.dtype.category() == DTypeCategory::LOWBIT);
        MGB_MARK_USED_VAR(cond_normal);
        MGB_MARK_USED_VAR(cond_low_bit);
        mgb_assert(
                cond_normal || cond_low_bit, "unsupported layout in profiling: %s",
                src.to_string().c_str());
        param.dtypes[i] = src.dtype.enumv();
    }
    param.comp_node_physical = m_cn.locator();
    param.comp_node_logical = m_cn.locator_logical();
    mgb_assert(param.shapes.size() == m_fastrun_layouts.size());
    for (size_t i = 0; i < param.shapes.size(); ++i)
        param.shapes[i] = m_fastrun_layouts[i];
    param.opr_param = m_dnn_opr->param();
    param.allow_weight_preprocess = m_allow_weight_preprocess;
    Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
    mgb_assert(palgo, "can not find algo when profile single algo");
    auto rst = TimedProfiler<Opr>::profile(param, timeout);
    // MIOpen conv profiles all available algos when a specfic shape is
    // provided for the first time, which probably adds to the result time.
    // Therefore, a second profile execution is needed.
    if (strncmp(palgo->name(), "MIOpen", 6) == 0) {
        rst = TimedProfiler<Opr>::profile(param, timeout);
    }
    if (!rst.valid())
        return None;
    // subprocess will return dbl_max when meomry limit is not satisfied
    if (rst.val().time == std::numeric_limits<double>::max())
        return None;
    // the cache stores the serialized algo desc as the entry key
    std::string algo_desc;
    serialize_write_pod(policy.algo, algo_desc);
    return AlgoChooserProfileCache::ResultEntry{
            algo_desc, static_cast<uint32_t>(palgo->attribute()), rst.val().time,
            param.workspace};
    MIDOUT_E
}
//! Run fast-run profiling for this operator: time every candidate algorithm
//! that is not already in the profile cache, filter by negative attribute
//! and workspace limit, then merge the new timings with any previously
//! cached results and write them back to the cache.
template <typename Opr>
void AlgoChooser<Opr>::AlgoChooserHelper::profile(
        const ExecutionStrategy& selected_strategy) const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile")))
    // some sub oprs have been profiled before
    // sub oprs won't be checked at the beginning of choose_by_profile
    auto&& rst = get_profile_result_from_cache(selected_strategy);
    // rst.first.valid means there exists valid algorithms for current opr, just return
    // otherwise need to profile
    // in order to avoid reprofile in fastrun
    if (rst.first.valid())
        return;
    AlgoChooserProfileCache::Result prof_rst;
    auto target_attr = extract_algo_attribute(selected_strategy);
    std::string layouts_str = AlgoChooser::format_fixlayouts(m_fastrun_layouts);
    double cur_timeout = 0;
    // total bytes of all input/output tensors; counted together with each
    // algorithm's workspace against the workspace limit below
    size_t data_size = 0;
    for (auto ly : m_fastrun_layouts)
        data_size += ly.span().dist_byte();
    auto workspace_limit =
            m_desc.get_workspace_limit(m_cn, m_execution_policy.workspace_limit);
    RealTimer timer;
    // serialized descs of algorithms already present in the cache; these are
    // skipped so they are not re-profiled
    std::unordered_set<std::string> rst_algos;
    if (rst.second.valid()) {
        std::transform(
                rst.second.val().begin(), rst.second.val().end(),
                std::inserter(rst_algos, rst_algos.end()),
                [](const AlgoChooserProfileCache::ResultEntry& result) {
                    return result.algo;
                });
    }
    for (auto algo : get_all_candidates()) {
        std::string desc;
        serialize_write_pod(algo.desc, desc);
        if (rst_algos.find(desc) != rst_algos.end()) {
            continue;
        }
        Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst;
        ImplExecutionPolicy policy;
        policy.algo = algo.desc;
        // skip naive algo, can not using attribute to determine naive algo, thus using
        // strcmp
        if (algo.desc.name.compare("NAIVE") == 0) {
            continue;
        }
        //! check negative attribute : skip negative attribute
        auto palgo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
        if (palgo->contain_attribute_any(target_attr.second)) {
            mgb_log_debug(
                    "skip algo %s, which matches the profile strategy required "
                    "'not contain attribute(%s).'",
                    algo.desc.name.c_str(),
                    Algorithm::attribute_str(target_attr.second).c_str());
            continue;
        }
        //! check workspace limit
        construct_execution_policy(selected_strategy, policy);
        // this will failed
        // when construct matmul algorithm for convolution opr
        if (!policy.algo.valid())
            continue;
        size_t workspace_needed = get_workspace_size_bytes(policy);
        if (data_size + workspace_needed >
            m_desc.get_workspace_limit(m_cn, m_execution_policy.workspace_limit)) {
            continue;
        }
        std::string msg = ssprintf(
                "profiling %s algorithm %s %s", ::MegDNNOpr2Typename<Opr>::name,
                algo.desc.name.c_str(), layouts_str.c_str());
        timer.reset();
        MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); }
        // megbrain catched exception
        MGB_CATCH(std::exception & exc, {
            mgb_log_debug("caught exception during %s: %s", msg.c_str(), exc.what());
            continue;
        })
        // megbrain uncatched exception
        MGB_CATCH(..., {
            mgb_log_debug("caught exception during %s", msg.c_str());
            continue;
        })
        if (!cur_rst.valid()) {
            mgb_log_debug(
                    "timeout when %s; timeout setting: %.3fsec", msg.c_str(),
                    cur_timeout);
            continue;
        }
        // adaptively tighten the timeout: no later algo needs to run much
        // longer than the fastest one observed so far plus a tolerance
        if (!cur_timeout) {
            cur_timeout = timer.get_secs() + TIMEOUT_TOLERANCE;
        } else {
            cur_timeout = std::min(cur_timeout, timer.get_secs() + TIMEOUT_TOLERANCE);
        }
        auto&& rst = cur_rst.val();
        mgb_log_debug(
                "%s: workspace: %zu; time: %.3gsec", msg.c_str(), rst.workspace,
                rst.time);
        prof_rst.push_back(rst);
    }
    // NOTE(review): `msg` below is constructed but never used afterwards;
    // presumably a leftover from a former assert on empty results -- confirm
    std::string msg = ssprintf(
            "no usable %s algorithm %s without attribute(%s) or could not meet "
            "workspace limite requirement(%zu)",
            ::MegDNNOpr2Typename<Opr>::name, layouts_str.c_str(),
            Algorithm::attribute_str(target_attr.second).c_str(), workspace_limit);
    // allowed to have empty profile result for current opr
    // append some previous profiled results
    if (rst.second.valid())
        prof_rst.insert(
                prof_rst.end(), rst.second.val().begin(), rst.second.val().end());
    if (!prof_rst.empty()) {
        FixedTensorLayouts incache_layouts = m_incache_layouts;
        typename Opr::Param origin_param = m_dnn_opr->param();
        AlgoChooserProfileCache::Key cache_key{
                incache_layouts.data(), incache_layouts.size(), &origin_param,
                sizeof(origin_param)};
        AlgoChooserProfileCache cache(m_cn, profile_name(m_dnn_opr).c_str());
        cache.put(cache_key, prof_rst);
    }
    MIDOUT_E
}
//! Build a fake PreprocessFilter whose tensors carry the deduced
//! preprocessed-filter layouts but null data pointers; it is used only for
//! workspace-size queries and must never be dereferenced.
//! \param layouts optional layout override; when its first layout has
//!        ndim == 0, the helper's fast-run layouts are used.
//! \return None when the opr does not support preprocessing, weight
//!         preprocessing is disabled, or no preprocessing is needed.
template <typename Opr>
Maybe<PreprocessFilter<Opr>> AlgoChooser<Opr>::AlgoChooserHelper::
        construct_fake_preprocess_filter(const FixedTensorLayouts& layouts) const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_fake_preprocess_filter")))
    Maybe<PreprocessFilter<Opr>> result = None;
    const FixedTensorLayouts* layouts_ptr = &m_fastrun_layouts;
    if (layouts.at(0).ndim) {
        layouts_ptr = &layouts;
    }
    if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
        if (!m_allow_weight_preprocess)
            return;
        auto opr = _(m_dnn_opr);
        // NOTE: this `layouts` shadows the function parameter; it holds the
        // deduced preprocessed-filter layouts
        auto layouts =
                APPLY(opr->deduce_preprocessed_filter_layout(args...), *layouts_ptr);
        //! No preprocess layout means no need weight preprocess
        if (layouts.empty()) {
            return;
        }
        //! all layouts are empty means no need weight preprocess
        bool layout_valid = false;
        for (auto&& layout : layouts) {
            if (!layout.is_empty()) {
                layout_valid = true;
            }
        }
        if (!layout_valid) {
            return;
        }
        result = PreprocessFilter<Opr>{};
        auto& res = result.val();
        res.algorithm_id = nullptr;
        res.tensors.resize(layouts.size());
        for (size_t i = 0; i < layouts.size(); i++) {
            // null data pointer: only the layout matters for workspace query
            res.tensors[i] = megdnn::TensorND(nullptr, layouts[i]);
        }
    });
    return result;
    MIDOUT_E
}
  975. template <typename Opr>
  976. std::pair<AlgoAttribute, AlgoAttribute> AlgoChooser<Opr>::AlgoChooserHelper::
  977. extract_algo_attribute(const ExecutionStrategy& strategy) const {
  978. std::pair<AlgoAttribute, AlgoAttribute> ret =
  979. std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
  980. //! from strategy
  981. if (strategy & ExecutionStrategy::REPRODUCIBLE) {
  982. ret.first |= AlgoAttribute::REPRODUCIBLE;
  983. }
  984. if (strategy & ExecutionStrategy::OPTMIZED) {
  985. ret.second |= AlgoAttribute::NAIVE;
  986. }
  987. //! from graph option
  988. // FIXME: no_profiling_on_shape_change extract USABLE_DEPEND_ON_SHAPE
  989. // attribute when fixed usable
  990. if (m_desc.shared_batch_size) {
  991. ret.second |= AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
  992. }
  993. if (m_desc.binary_equal_between_batch) {
  994. ret.first |= AlgoAttribute::REPRODUCIBLE;
  995. ret.second |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
  996. }
  997. return ret;
  998. }
// Explicit template instantiations of AlgoChooserHelper's public interface
// for every fast-run capable operator (expanded via DNN_FOREACH_FASTRUN_OPR),
// so the member definitions above may stay in this translation unit.
#define INST(Opr)                                                                   \
    template AlgoChooser<megdnn::Opr>::AlgoChooserHelper::AlgoChooserHelper(        \
            const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr,             \
            const std::string& param_str, const CompNode& cn,                       \
            const megdnn::param::ExecutionPolicy& execution_policy,                 \
            bool allow_weight_preprocess, const AlgoChooserDesc& desc);             \
    template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy                 \
    AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_heuristic(               \
            const ExecutionStrategy& select_strategy) const;                        \
    template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy                 \
    AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_profile(                 \
            const ExecutionStrategy& select_strategy, bool enable_update) const;    \
    template typename std::pair<                                                    \
            AlgoChooser<megdnn::Opr>::ImplAlgoDesc,                                 \
            Maybe<AlgoChooserProfileCache::Result>>                                 \
    AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_profile_result_from_cache(     \
            const ExecutionStrategy& select_strategy) const;                        \
    template void                                                                   \
    AlgoChooser<megdnn::Opr>::AlgoChooserHelper::construct_execution_policy(        \
            const ExecutionStrategy& select_strategy,                               \
            typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy,         \
            bool retrive_from_cache, bool allow_log) const;                         \
    template size_t                                                                 \
    AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_workspace_size_bytes(          \
            const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy,   \
            const FixedTensorLayouts& layouts) const;                               \
    template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo>               \
    AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_all_candidates() const;        \
    template Maybe<AlgoChooserProfileCache::ResultEntry>                            \
    AlgoChooser<megdnn::Opr>::AlgoChooserHelper::profile_single_algo(               \
            const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy,   \
            double& timeout) const;                                                 \
    template std::pair<AlgoAttribute, AlgoAttribute>                                \
    AlgoChooser<megdnn::Opr>::AlgoChooserHelper::extract_algo_attribute(            \
            const ExecutionStrategy& strategy) const;                               \
    template void AlgoChooser<megdnn::Opr>::AlgoChooserHelper::profile(             \
            const ExecutionStrategy& selected_strategy) const;
DNN_FOREACH_FASTRUN_OPR(INST)
#undef INST
  1038. //////////////////////////////// AlgoChoose /////////////////////////////
  1039. template <typename Opr>
  1040. typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
  1041. const AlgoChooserHelper& helper) {
  1042. auto opr_strategy = helper.execution_policy().strategy;
  1043. auto strategy2str = [](auto strategy) {
  1044. std::string ret;
  1045. if (strategy & ExecutionStrategy::HEURISTIC) {
  1046. ret += "HEURISTIC ";
  1047. }
  1048. if (strategy & ExecutionStrategy::PROFILE) {
  1049. ret += "PROFILE ";
  1050. }
  1051. if (strategy & ExecutionStrategy::REPRODUCIBLE) {
  1052. ret += "REPRODUCIBLE ";
  1053. }
  1054. if (strategy & ExecutionStrategy::OPTIMIZED) {
  1055. ret += "OPTIMIZED ";
  1056. }
  1057. return ret;
  1058. };
  1059. mgb_log_debug("Use Stragegy :%s", strategy2str(opr_strategy).c_str());
  1060. if (opr_strategy & ExecutionStrategy::HEURISTIC) {
  1061. if (opr_strategy & ExecutionStrategy::PROFILE) {
  1062. //! this strategy will choose from cache first, then choost by
  1063. //! heuristic if fail.
  1064. ImplExecutionPolicy policy = helper.choose_by_profile(opr_strategy, false);
  1065. if (!policy.algo.valid()) {
  1066. policy = helper.choose_by_heuristic(opr_strategy);
  1067. }
  1068. return policy;
  1069. } else {
  1070. return helper.choose_by_heuristic(opr_strategy);
  1071. }
  1072. }
  1073. #if MGB_ENABLE_FASTRUN
  1074. else if (opr_strategy & ExecutionStrategy::PROFILE) {
  1075. return helper.choose_by_profile(opr_strategy, true);
  1076. }
  1077. #endif
  1078. else {
  1079. mgb_throw(InternalError, "bad ExecutionPolicy strategy");
  1080. }
  1081. }
  1082. template <typename Opr>
  1083. std::string AlgoChooser<Opr>::format_fixlayouts(const FixedTensorLayouts& layout) {
  1084. return ::format_fixlayouts<Opr>(layout, arity_in, arity_out);
  1085. }
  1086. #define INST(Opr) \
  1087. template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
  1088. AlgoChooser<megdnn::Opr>::get_policy(const AlgoChooserHelper& proxy); \
  1089. template std::string AlgoChooser<Opr>::format_fixlayouts( \
  1090. const FixedTensorLayouts& layout);
  1091. DNN_FOREACH_FASTRUN_OPR(INST)
  1092. #undef INST
  1093. } // namespace rdnn
  1094. } // namespace mgb
  1095. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}