

  1. /**
  2. * \file src/opr/impl/cond.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/cond.h"
  12. #include "megbrain/graph/event.h"
  13. #include "megbrain/graph/grad_impl.h"
  14. #include "megbrain/opr/basic_arith.h"
  15. #include "megbrain/opr/utility.h"
  16. using namespace mgb;
  17. using namespace opr;
  18. #if MGB_ENABLE_COND_EXEC
  19. namespace {
  20. //! return whether ``lhs -> rhs`` can be proved
  21. bool can_prove_imply(cg::ExecutionMask* lhs, cg::ExecutionMask* rhs) {
  22. // this function is neither sound nor complete (and it can never be due
  23. // to the NP-completeness of SAT); here we only handle the most common
  24. // cases
  25. if (rhs == lhs->parent()) {
  26. // nested cond exec oprs
  27. return true;
  28. }
  29. using Mode = CondExecPredLogical::Mode;
  30. auto is_pred_logical = [](cg::OperatorNodeBase* opr, Mode mode) {
  31. auto as_p = opr->try_cast_final<CondExecPredLogical>();
  32. return as_p && as_p->param().mode == mode;
  33. };
  34. auto opr = rhs->owner()->owner_opr();
  35. if (is_pred_logical(opr, Mode::AND) && opr->input().size() == 1) {
  36. // cross-cn copy of predicate
  37. opr = opr->input(0)->owner_opr();
  38. }
  39. if (is_pred_logical(opr, Mode::OR)) {
  40. // in the grad of SUM_COND_OUT CondExecMerge
  41. auto lvar = lhs->owner();
  42. for (auto i : opr->input()) {
  43. if (lvar == i) {
  44. return true;
  45. }
  46. }
  47. return false;
  48. }
  49. return false;
  50. }
  51. VarNode* proxy_var_from_mask(cg::ExecutionMask* mask) {
  52. auto var = mask->owner();
  53. mgb_assert(var);
  54. auto opr = var->owner_opr();
  55. auto type = opr->dyn_typeinfo();
  56. mgb_assert(type->is<CondExecPred>() || type->is<CondExecPredLogical>(),
  57. "mask not from CondExec opr: %s",
  58. cg::dump_var_info({var}).c_str());
  59. return var;
  60. }
  61. #if MGB_ENABLE_LOGGING
  62. std::string mask2str(cg::ExecutionMask* mask) {
  63. if (!mask) {
  64. return "null";
  65. }
  66. auto var = mask->owner();
  67. mgb_assert(var);
  68. if (var->owner_opr()->same_type<CondExecPred>()) {
  69. return ssprintf("CondExecPred(%s)", var->cname());
  70. }
  71. mgb_assert(var->owner_opr()->same_type<CondExecPredLogical>());
  72. return ssprintf("CondExecPredLogical(%s)", var->cname());
  73. }
  74. #else
  75. std::string mask2str(cg::ExecutionMask*) {
  76. return "";
  77. }
  78. #endif
  79. } // anonymous namespace
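/* Editor's overview (inferred from the operators defined in this file):
 * CondExecPred evaluates a scalar predicate against a set of keys and emits
 * one predicate proxy var (PPV) per branch, each backed by an ExecutionMask;
 * CondExecMark gates ordinary data vars on a PPV so the branch body only
 * executes when that branch is active; CondExecPredLogical combines PPVs with
 * boolean ops; CondExecMerge joins the per-branch results back into
 * unconditional vars. A rough, illustrative composition sketch (the Param
 * values `case_fallback_param` / `exact_one_param` are hypothetical
 * placeholders, not exact constructor syntax):
 *
 *   // PPVs for pred == k0, pred == k1, plus a fallback branch
 *   auto pred_opr = CondExecPred::make_opr(pred, {k0, k1},
 *                                          case_fallback_param, {});
 *   // gate the data var x on the first branch
 *   auto x0 = CondExecMark::make_opr(pred_opr->output(0), {x}, {}, {})
 *                     ->output(0);
 *   // ... compute y0 / y1 / y2 from the gated inputs, then join:
 *   auto y = CondExecMerge::make_opr({y0, y1, y2}, {},
 *                                    exact_one_param, {})->output(0);
 */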
  80. /* ============================= CondExecPred ============================= */
  81. MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecPred);
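// Helper that captures the predicate value once and reports how each branch
// key compares with it (LT / EQ / GT). For floating-point dtypes equality
// means |pred - key| < m_param.eps; for integer dtypes it is exact equality.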
  82. class CondExecPred::PredEvaluator {
  83. public:
  84. enum Result { LT, EQ, GT };
  85. PredEvaluator(const CondExecPred& opr, const DeviceTensorND& pred);
  86. Result operator()(const DeviceTensorND& key) {
  87. pre_check(key);
  88. return m_compare(key);
  89. }
  90. private:
  91. CompNode default_cpu = CompNode::default_cpu();
  92. thin_function<Result(const DeviceTensorND&)> m_compare;
  93. void pre_check(const DeviceTensorND& val) {
  94. mgb_assert(val.comp_node() == default_cpu);
  95. mgb_throw_if(!val.shape().is_scalar(), GraphError,
  96. "CondExec predicate or branch key is not scalar: %s",
  97. val.shape().to_string().c_str());
  98. }
  99. };
  100. CondExecPred::PredEvaluator::PredEvaluator(const CondExecPred& opr,
  101. const DeviceTensorND& pred) {
  102. pre_check(pred);
  103. switch (pred.dtype().enumv()) {
  104. #define cbf(dt) \
  105. case DTypeTrait<dt>::enumv: { \
  106. using ct = DTypeTrait<dt>::ctype; \
  107. m_compare = [ eps = opr.m_param.eps, \
  108. p = pred.ptr<ct>()[0] ](const DeviceTensorND& key) { \
  109. ct k = key.ptr<ct>()[0]; \
  110. return std::abs(p - k) < eps ? EQ : (p < k ? LT : GT); \
  111. }; \
  112. break; \
  113. }
  114. #define cbi(dt) \
  115. case DTypeTrait<dt>::enumv: { \
  116. using ct = DTypeTrait<dt>::ctype; \
  117. m_compare = [p = pred.ptr<ct>()[0]](const DeviceTensorND& key) { \
  118. ct k = key.ptr<ct>()[0]; \
  119. return p == k ? EQ : (p < k ? LT : GT); \
  120. }; \
  121. break; \
  122. }
  123. MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cbf);
  124. MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cbi)
  125. #undef cbf
  126. #undef cbi
  127. default:
  128. mgb_throw(GraphError, "unsupported pred dtype: %s",
  129. pred.dtype().name());
  130. }
  131. }
  132. class CondExecPred::GlobalRegistry final : public UserDataContainer::UserData {
  133. MGB_TYPEINFO_OBJ_DECL;
  134. SyncEventConnecter::ReceiverHandler m_opr_insert_handler;
  135. ThinHashMap<VarNode*, ExecutionMask*> m_var2mask;
  136. void on_new_opr(OperatorNodeBase* opr);
  137. public:
  138. static GlobalRegistry* get(ComputingGraph& graph) {
  139. using namespace cg::event;
  140. auto ptr = graph.options()
  141. .user_data.get_user_data_or_create<GlobalRegistry>();
  142. if (!ptr->m_opr_insert_handler) {
  143. ptr->m_opr_insert_handler =
  144. graph.event().register_receiver<OprInserted>(
  145. [ptr](const OprInserted& ev) {
  146. if (!ev.is_dedup && !ev.exc) {
  147. ptr->on_new_opr(ev.opr);
  148. }
  149. });
  150. }
  151. return ptr;
  152. }
  153. //! get mask if var is conditional, or nullptr otherwise
  154. ExecutionMask* get_mask_from_var(VarNode* var) const {
  155. auto iter = m_var2mask.find(var);
  156. return iter == m_var2mask.end() ? nullptr : iter->second;
  157. }
  158. //! throw error if var is not controlled by ExecutionMask
  159. ExecutionMask* require_mask_from_var(VarNode* var) const {
  160. auto mask = get_mask_from_var(var);
  161. mgb_throw_if(!mask, GraphError,
  162. "var is not controlled by ExecutionMask: %s",
  163. cg::dump_var_info({var}).c_str());
  164. return mask;
  165. }
  166. //! assert that a var is a PPV
  167. ExecutionMask* check_ppv(VarNode* var) const {
  168. auto mask = require_mask_from_var(var);
  169. mgb_throw_if(mask->owner() != var, GraphError,
  170. "a conditional var is not PPV: mask=%s var=%s",
  171. mask2str(mask).c_str(), cg::dump_var_info({var}).c_str());
  172. return mask;
  173. }
  174. };
  175. MGB_TYPEINFO_OBJ_IMPL(CondExecPred::GlobalRegistry);
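// Invoked for every newly inserted opr (via the OprInserted event registered
// in GlobalRegistry::get). It derives the ExecutionMask that should control
// the opr from the masks of its inputs, checks that those masks are
// consistent (or provably implied for mark/merge oprs that follow a
// predicate), registers the opr under the resulting mask, and records the
// mask of every output var in m_var2mask.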
  176. void CondExecPred::GlobalRegistry::on_new_opr(OperatorNodeBase* const opr) {
  177. // mask that controls execution of this opr
  178. ExecutionMask* mask = nullptr;
  179. auto opr_type = opr->dyn_typeinfo();
  180. bool opr_is_mark = opr_type->is<CondExecMark>(),
  181. opr_is_merge = opr_type->is<CondExecMerge>(),
  182. opr_is_pred_logical = opr_type->is<CondExecPredLogical>();
  183. using MergeMode = CondExecMerge::Mode;
  184. MergeMode merge_mode =
  185. opr_is_merge ? opr->cast_final<CondExecMerge>().param().mode
  186. : static_cast<MergeMode>(-1);
  187. bool opr_follow_pred =
  188. opr_is_mark ||
  189. (opr_is_merge && merge_mode == MergeMode::SUM_COND_OUT);
  190. // find mask from inputs
  191. auto&& inputs = opr->input();
  192. for (size_t idx = 0; idx < inputs.size(); ++idx) {
  193. auto i_var = inputs[idx];
  194. ExecutionMask* i_mask = nullptr;
  195. auto i_owner = i_var->owner_opr();
  196. bool i_is_pred = false;
  197. if (i_owner->same_type<CondExecPred>() ||
  198. i_owner->same_type<CondExecPredLogical>()) {
  199. i_is_pred = true;
  200. mgb_throw_if(!((opr_follow_pred && i_var == opr->input().back()) ||
  201. opr_is_pred_logical),
  202. GraphError,
  203. "predicate proxy var not received by CondExec "
  204. "mark/merge opr: var=%s recv_opr=%s{%s}",
  205. cg::dump_var_info({i_var}).c_str(), opr->cname(),
  206. opr->dyn_typeinfo()->name);
  207. }
  208. if (opr_follow_pred && i_var == opr->input().back()) {
  209. // CondExecMerge(with SUM_COND_OUT) and CondExecMark are controlled
  210. // by given pred
  211. mgb_assert(i_is_pred);
  212. i_mask = m_var2mask.at(i_var);
  213. if (mask) {
  214. // here we handle the nested case; note that pred is the last
  215. // input, so other inputs have been processed and mask is
  216. // derived from other inputs
  217. mgb_throw_if(!can_prove_imply(i_mask, mask), GraphError,
  218. "can not prove opr mask implies inputs mask: "
  219. "opr=%s{%s}: opr_mask=%s "
  220. "inputs_mask=%s",
  221. opr->cname(), opr->dyn_typeinfo()->name,
  222. mask2str(i_mask).c_str(), mask2str(mask).c_str());
  223. }
  224. mask = i_mask;
  225. break;
  226. }
  227. if (!i_mask) {
  228. auto iter = m_var2mask.find(i_var);
  229. i_mask = iter == m_var2mask.end() ? nullptr : iter->second;
  230. }
  231. if (opr_is_pred_logical && i_mask) {
  232. // CondExecPredLogical should only combine preds from the
  233. // higher-level same mask
  234. i_mask = i_mask->parent();
  235. }
  236. if (opr_is_merge) {
  237. if (merge_mode == MergeMode::SUM &&
  238. idx >= inputs.size() - opr->output().size()) {
  239. // the remaining inputs are output shapes; if they can not be
  240. // statically inferred, their execution mask must be on the same
  241. // level of this CondExecMerge, so we do not modify i_mask
  242. if (cg::is_static_var_value(i_var)) {
  243. // no need to add execution mask for statically inferrable
  244. // values
  245. i_mask = nullptr;
  246. }
  247. } else if (i_mask) {
  248. // execution of merge opr is controlled by mask at a higher
  249. // level
  250. i_mask = i_mask->parent();
  251. }
  252. }
  253. if (i_mask) {
  254. auto lower = ExecutionMask::find_direct_lowest(mask, i_mask);
  255. mgb_throw_if(!lower, GraphError,
  256. "different ExecutionMask trees on inputs of a single "
  257. "opr: opr=%s{%s} mask0=%s mask1=%s",
  258. opr->cname(), opr->dyn_typeinfo()->name,
  259. mask2str(mask).c_str(), mask2str(i_mask).c_str());
  260. mask = lower;
  261. }
  262. }
  263. if (mask) {
  264. mask->register_to_opr(opr);
  265. for (auto i : opr->output()) {
  266. m_var2mask[i] = mask;
  267. }
  268. }
  269. // register nested masks and record var2mask map
  270. if (opr_type->is<CondExecPred>()) {
  271. size_t idx = 0;
  272. for (auto&& i : opr->cast_final<CondExecPred>().masks()) {
  273. if (mask) {
  274. mask->add_nested(i.get());
  275. }
  276. m_var2mask[opr->output(idx++)] = i.get();
  277. }
  278. } else if (opr_is_pred_logical) {
  279. auto m = opr->cast_final<CondExecPredLogical>().mask();
  280. if (mask) {
  281. mask->add_nested(m);
  282. }
  283. m_var2mask[opr->output(0)] = m;
  284. }
  285. }
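// Output layout of CondExecPred (see the constructor below): CASE mode yields
// one Int32 PPV per key; CASE_FALLBACK appends a "fallback" PPV that becomes
// active when no key matches; PIECEWISE yields keys.size() + 1 PPVs covering
// the half-open intervals (-inf, k0), [k0, k1), ..., [k_{n-1}, +inf)
// (up to the eps tolerance for floating-point predicates).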
  286. CondExecPred::CondExecPred(VarNode* pred, const VarNodeArrayView& keys,
  287. const Param& param, const OperatorNodeConfig& config)
  288. : Super(pred->owner_graph(), config, "cond_pred", {pred}),
  289. m_param{param} {
  290. m_masks.reserve(keys.size() + 1);
  291. auto add_out = [this](const std::string& name) {
  292. auto var = add_output(name);
  293. var->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC).dtype(dtype::Int32{});
  294. m_masks.emplace_back(std::make_shared<ExecutionMask>(var));
  295. };
  296. for (size_t i = 0; i < keys.size(); ++i) {
  297. mgb_throw_if(keys[i]->dtype() != pred->dtype(), GraphError,
  298. "dtype mismatch: pred=%s input[%zu]=%s",
  299. pred->dtype().name(), i, keys[i]->dtype().name());
  300. add_input({keys[i]});
  301. if (param.mode == Param::Mode::PIECEWISE) {
  302. if (!i) {
  303. add_out("[-inf,k0]");
  304. }
  305. if (i != keys.size() - 1) {
  306. add_out(ssprintf("[k%zu,k%zu]", i, i + 1));
  307. } else {
  308. add_out(ssprintf("[k%zu,inf]", i));
  309. }
  310. } else {
  311. add_out(ssprintf("branch%zu", i));
  312. }
  313. }
  314. if (param.mode == Param::Mode::CASE_FALLBACK) {
  315. add_out("fallback");
  316. }
  317. add_input({pred});
  318. add_equivalence_component<PODHash<Param>>(&m_param);
  319. // ensure listener is registered
  320. GlobalRegistry::get(*owner_graph());
  321. }
  322. cg::OperatorNodeBase* CondExecPred::make_opr(SymbolVar pred,
  323. const VarNodeArrayView& keys,
  324. const Param& param,
  325. const OperatorNodeConfig& config) {
  326. return pred.node()->owner_graph()->insert_opr(
  327. std::make_unique<CondExecPred>(pred.node(), keys, param, config));
  328. }
  329. void CondExecPred::init_output_static_infer_desc() {
  330. using namespace cg::static_infer;
  331. auto&& mgr = owner_graph()->static_infer_manager();
  332. for (auto i : output()) {
  333. mgr.register_shape_infer(i, ShapeInferDesc::make_const({1}));
  334. }
  335. auto reg_value_infer_no_const = [&mgr](VarNode* var, ValueInferDesc& desc) {
  336. auto orig_size = desc.deps.size();
  337. mixin::ForwardInputToOutput::ensure_not_replaced_by_const_folding(desc);
  338. mgr.register_value_infer(var, desc);
  339. if (desc.deps.size() != orig_size) {
  340. // remove newly added dep
  341. mgb_assert(desc.deps.size() == orig_size + 1);
  342. desc.deps.pop_back();
  343. }
  344. };
  345. size_t nr_key = input().size() - 1;
  346. auto mode = m_param.mode;
  347. if (mode == Mode::CASE || mode == Mode::CASE_FALLBACK) {
  348. auto infer_val_eq = [this](DeviceTensorND& dest, const InpVal& inp) {
  349. auto&& pv = inp.val[0].value();
  350. auto&& key = inp.val[1].value();
  351. dest.resize({1}).ptr<int>()[0] =
  352. (PredEvaluator{*this, pv}(key) == PredEvaluator::EQ);
  353. return true;
  354. };
  355. ValueInferDesc desc{
  356. SourceType::DEP,
  357. {{input().back(), DepType::VALUE}, {nullptr, DepType::VALUE}},
  358. infer_val_eq};
  359. for (size_t i = 0; i < nr_key; ++i) {
  360. desc.deps[1].dest = input(i);
  361. reg_value_infer_no_const(output(i), desc);
  362. }
  363. if (mode == Mode::CASE_FALLBACK) {
  364. desc.deps.clear();
  365. for (size_t i = 0; i < nr_key; ++i) {
  366. desc.deps.push_back({output(i), DepType::VALUE});
  367. }
  368. desc.infer_func = [](DeviceTensorND& dest, const InpVal& inp) {
  369. int r = 1;
  370. for (auto&& i : inp.val) {
  371. if (i.value().ptr<int>()[0]) {
  372. r = 0;
  373. break;
  374. }
  375. }
  376. dest.resize({1}).ptr<int>()[0] = r;
  377. return true;
  378. };
  379. reg_value_infer_no_const(output().back(), desc);
  380. }
  381. } else {
  382. mgb_assert(mode == Mode::PIECEWISE);
  383. auto infer_first = [this](DeviceTensorND& dest, const InpVal& inp) {
  384. auto&& pv = inp.val[0].value();
  385. auto&& key = inp.val[1].value();
  386. dest.resize({1}).ptr<int>()[0] =
  387. (PredEvaluator{*this, pv}(key) == PredEvaluator::LT);
  388. return true;
  389. };
  390. auto infer_mid = [this](DeviceTensorND& dest, const InpVal& inp) {
  391. auto&& pv = inp.val[0].value();
  392. auto&& left = inp.val[1].value();
  393. auto&& right = inp.val[2].value();
  394. PredEvaluator eval{*this, pv};
  395. auto el = eval(left), er = eval(right);
  396. dest.resize({1}).ptr<int>()[0] =
  397. (el != PredEvaluator::LT && er == PredEvaluator::LT);
  398. return true;
  399. };
  400. auto infer_last = [this](DeviceTensorND& dest, const InpVal& inp) {
  401. auto&& pv = inp.val[0].value();
  402. auto&& key = inp.val[1].value();
  403. dest.resize({1}).ptr<int>()[0] =
  404. (PredEvaluator{*this, pv}(key) != PredEvaluator::LT);
  405. return true;
  406. };
  407. // (-inf, key[0])
  408. ValueInferDesc desc{
  409. SourceType::DEP,
  410. {{input().back(), DepType::VALUE}, {input(0), DepType::VALUE}},
  411. infer_first};
  412. reg_value_infer_no_const(output(0), desc);
  413. // [key[i-1], key[i])
  414. desc.deps.push_back({nullptr, DepType::VALUE});
  415. desc.infer_func = infer_mid;
  416. for (size_t i = 1; i < nr_key; ++i) {
  417. desc.deps[1].dest = input(i - 1);
  418. desc.deps[2].dest = input(i);
  419. reg_value_infer_no_const(output(i), desc);
  420. }
  421. // [key[n-1], inf)
  422. desc.deps.resize(2);
  423. desc.deps[1].dest = input(nr_key - 1);
  424. desc.infer_func = infer_last;
  425. reg_value_infer_no_const(output(nr_key), desc);
  426. }
  427. }
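// All input deps are HOST_VALUE: the predicate and keys are read on the host
// (PredEvaluator requires default_cpu tensors) so the branch masks can be
// switched in scn_do_execute.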
  428. CondExecPred::NodeProp* CondExecPred::do_make_node_prop() const {
  429. auto ret = Super::do_make_node_prop();
  430. for (auto&& i : ret->dep_map()) {
  431. i.second = NodeProp::DepType::HOST_VALUE;
  432. }
  433. return ret;
  434. }
  435. void CondExecPred::scn_do_execute() {
  436. auto&& mgr = owner_graph()->static_infer_manager();
  437. PredEvaluator eval{*this, mgr.infer_value(input().back())};
  438. auto mode = m_param.mode;
  439. if (mode == Mode::CASE || mode == Mode::CASE_FALLBACK) {
  440. bool enabled = false;
  441. for (size_t i = 0; i < input().size() - 1; ++i) {
  442. auto cur = eval(mgr.infer_value(input(i))) == PredEvaluator::EQ;
  443. m_masks[i]->enable(cur);
  444. enabled |= cur;
  445. }
  446. if (mode == Mode::CASE_FALLBACK) {
  447. m_masks.back()->enable(!enabled);
  448. }
  449. } else {
  450. mgb_assert(mode == Mode::PIECEWISE);
  451. const DeviceTensorND *val_prev = nullptr, *val_cur = nullptr;
  452. for (size_t i = 0; i < input().size(); ++i) {
  453. val_prev = val_cur;
  454. if (i == input().size() - 1) {
  455. val_cur = nullptr;
  456. } else {
  457. val_cur = &mgr.infer_value(input(i));
  458. }
  459. PredEvaluator::Result el, er;
  460. if (!val_prev) {
  461. el = PredEvaluator::GT;
  462. } else {
  463. el = eval(*val_prev);
  464. }
  465. if (!val_cur) {
  466. er = PredEvaluator::LT;
  467. } else {
  468. er = eval(*val_cur);
  469. }
  470. m_masks[i]->enable(el != PredEvaluator::LT &&
  471. er == PredEvaluator::LT);
  472. }
  473. }
  474. }
  475. VarNode* CondExecPred::out_var_from_mask(ExecutionMask* mask) const {
  476. for (size_t i = 0; i < output().size(); ++i) {
  477. if (mask == m_masks[i].get()) {
  478. return output(i);
  479. }
  480. }
  481. mgb_throw(AssertionError, "bad mask");
  482. }
  483. /* ========================== CondExecPredLogical ========================== */
  484. MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecPredLogical);
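// Folds boolean inputs one by one; update() returns false to allow early
// stopping once the result is decided (OR stops after seeing a 1, AND after
// seeing a 0, XOR never stops early). NOR/NAND/XNOR reuse the OR/AND/XOR
// updaters and negate the final value in get().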
  485. class CondExecPredLogical::PredEvaluator {
  486. //! return false to early stop
  487. bool (*m_updater)(int*, int);
  488. int m_cur_val, m_negate = 0;
  489. public:
  490. explicit PredEvaluator(Mode mode, int init) : m_cur_val{init} {
  491. auto fn_or = [](int* dst, int v) -> bool {
  492. *dst |= v;
  493. return !*dst;
  494. };
  495. auto fn_and = [](int* dst, int v) -> bool {
  496. *dst &= v;
  497. return *dst;
  498. };
  499. auto fn_xor = [](int* dst, int v) -> bool {
  500. *dst ^= v;
  501. return true;
  502. };
  503. switch (mode) {
  504. case Mode::NOR:
  505. m_negate = 1;
  506. // falls through
  507. case Mode::OR:
  508. m_updater = fn_or;
  509. break;
  510. case Mode::NAND:
  511. m_negate = 1;
  512. // falls through
  513. case Mode::AND:
  514. m_updater = fn_and;
  515. break;
  516. case Mode::XNOR:
  517. m_negate = 1;
  518. // falls through
  519. case Mode::XOR:
  520. m_updater = fn_xor;
  521. break;
  522. default:
  523. mgb_throw(MegBrainError, "invalid CondExecPredLogical mode");
  524. }
  525. }
  526. //! return false to early stop
  527. bool update(int val) { return m_updater(&m_cur_val, val); }
  528. bool get() const { return m_cur_val ^ m_negate; }
  529. };
  530. CondExecPredLogical::CondExecPredLogical(const VarNodeArrayView& preds,
  531. const Param& param,
  532. const OperatorNodeConfig& config)
  533. : Super(preds.at(0)->owner_graph(), config,
  534. mgb_cstr_log(mode2str(param.mode)), preds),
  535. m_param{param} {
  536. m_input_masks.resize(preds.size());
  537. auto gr = CondExecPred::GlobalRegistry::get(*owner_graph());
  538. for (size_t i = 0; i < preds.size(); ++i) {
  539. m_input_masks[i] = gr->require_mask_from_var(preds[i]);
  540. add_input({preds[i]}, i == preds.size() - 1 ? AddInputSortType::ALL
  541. : AddInputSortType::NONE);
  542. }
  543. add_output(None)
  544. ->dtype(dtype::Int32{})
  545. .add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
  546. m_mask = std::make_shared<ExecutionMask>(output(0));
  547. add_equivalence_component<PODHash<Param>>(&m_param);
  548. }
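// With a single input, OR/AND/XOR are identities, so make() returns the
// predicate unchanged; an opr is only inserted when the mode negates
// (NOR/NAND/XNOR) or a different comp node is requested, in which case it
// acts as a cross-comp-node copy of the predicate.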
  549. SymbolVar CondExecPredLogical::make(const VarNodeArrayView& preds,
  550. const Param& param,
  551. const OperatorNodeConfig& config) {
  552. mgb_assert(!preds.empty());
  553. if (preds.size() == 1) {
  554. if (!config.has_comp_node_set() ||
  555. config.get_single_comp_node() == preds[0]->comp_node()) {
  556. auto m = param.mode;
  557. if (m == Mode::OR || m == Mode::XOR || m == Mode::AND) {
  558. return preds[0];
  559. }
  560. }
  561. }
  562. return SymbolVar{preds[0]}.insert_single_output_opr<CondExecPredLogical>(
  563. preds, param, config);
  564. }
  565. void CondExecPredLogical::init_output_static_infer_desc() {
  566. using namespace cg::static_infer;
  567. auto&& mgr = owner_graph()->static_infer_manager();
  568. mgr.register_shape_infer(output(0), ShapeInferDesc::make_const({1}));
  569. auto infer_val = [mode = m_param.mode](DeviceTensorND & dst,
  570. const InpVal& inp) {
  571. PredEvaluator eval{mode, inp.val[0].value().ptr<int>()[0]};
  572. for (size_t i = 1; i < inp.val.size(); ++i) {
  573. if (!eval.update(inp.val[i].value().ptr<int>()[0])) {
  574. break;
  575. }
  576. }
  577. dst.resize({1}).ptr<int>()[0] = eval.get();
  578. return true;
  579. };
  580. ValueInferDesc desc;
  581. desc.src_type = SourceType::DEP;
  582. desc.deps.reserve(input().size());
  583. for (auto i : input()) {
  584. desc.deps.push_back({i, DepType::VALUE});
  585. }
  586. desc.infer_func = infer_val;
  587. mgr.register_value_infer(output(0), desc);
  588. }
  589. void CondExecPredLogical::scn_do_execute() {
  590. PredEvaluator eval{m_param.mode, m_input_masks[0]->enabled()};
  591. for (size_t i = 1; i < m_input_masks.size(); ++i) {
  592. if (!eval.update(m_input_masks[i]->enabled())) {
  593. break;
  594. }
  595. }
  596. m_mask->enable(eval.get());
  597. }
  598. CondExecPredLogical::NodeProp* CondExecPredLogical::do_make_node_prop() const {
  599. auto ret = Super::do_make_node_prop();
  600. for (auto&& i : ret->dep_map()) {
  601. i.second = NodeProp::DepType::DEV_COMP_ORDER;
  602. }
  603. ret->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
  604. return ret;
  605. }
  606. const char* CondExecPredLogical::mode2str(Mode mode) {
  607. switch (mode) {
  608. #define CASE(n) \
  609. case Mode::n: \
  610. return #n
  611. CASE(OR);
  612. CASE(AND);
  613. CASE(XOR);
  614. CASE(NOR);
  615. CASE(NAND);
  616. CASE(XNOR);
  617. default:
  618. mgb_throw(MegBrainError, "bad CondExecPredLogical mode: %d",
  619. static_cast<int>(mode));
  620. }
  621. }
  622. /* ============================= CondExecMark ============================= */
  623. MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecMark);
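// CondExecMark forwards each input var to a corresponding output while
// placing the outputs under the ExecutionMask of the given PPV, so everything
// computed from them only executes when that branch is active. Output storage
// is forwarded read-only from the inputs when possible; otherwise the value
// is copied in scn_do_execute.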
  624. CondExecMark::CondExecMark(VarNode* ppv, const VarNodeArrayView& inputs,
  625. const Param& param, const OperatorNodeConfig& config)
  626. : Super(ppv->owner_graph(), config, "cond_mark", {ppv}),
  627. m_param{param} {
  628. CondExecPred::GlobalRegistry::get(*owner_graph())->check_ppv(ppv);
  629. for (size_t i = 0; i < inputs.size(); ++i) {
  630. add_input({inputs[i]});
  631. add_output(ssprintf("fwd%zu", i))
  632. ->dtype(inputs[i]->dtype())
  633. .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
  634. }
  635. add_input({ppv});
  636. add_equivalence_component<PODHash<Param>>(&m_param);
  637. if (has_no_shape_infer()) {
  638. for (auto i : input()) {
  639. // force dynamic allocation of input so storage can be forwarded
  640. i->add_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC);
  641. }
  642. for (auto i : output()) {
  643. i->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
  644. }
  645. } else {
  646. m_mem_fwd_success.resize(inputs.size(), false);
  647. }
  648. }
  649. void CondExecMark::init_output_static_infer_desc() {
  650. using namespace cg::static_infer;
  651. auto&& mgr = owner_graph()->static_infer_manager();
  652. using InferMode = Param::StaticInfer;
  653. auto infer_mode = param().static_infer;
  654. if (infer_mode == InferMode::NONE) {
  655. return;
  656. }
  657. for (size_t i = 0; i < output().size(); ++i) {
  658. auto s = input(i), t = output(i);
  659. mgr.register_shape_infer(t, ShapeInferDesc::make_identity(s));
  660. if (infer_mode != InferMode::SHAPE_ONLY) {
  661. auto desc = ValueInferDesc::make_identity(s);
  662. mixin::ForwardInputToOutput::ensure_not_replaced_by_const_folding(
  663. desc);
  664. mgr.register_value_infer(t, desc);
  665. }
  666. }
  667. }
  668. void CondExecMark::scn_do_execute() {
  669. bool no_sys_alloc = has_no_shape_infer();
  670. for (size_t i = 0; i < output().size(); ++i) {
  671. if (no_sys_alloc) {
  672. bool succ = output(i)->reset_dev_tensor_from_other_var(input(i));
  673. MGB_MARK_USED_VAR(succ);
  674. } else {
  675. auto &&out = output(i)->dev_tensor(),
  676. &&inp = input(i)->dev_tensor();
  677. if (m_mem_fwd_success[i]) {
  678. mgb_assert(inp.raw_ptr() == out.raw_ptr() &&
  679. out.layout().eq_layout(inp.layout()));
  680. } else {
  681. out.copy_from_fixlayout(inp);
  682. }
  683. }
  684. }
  685. }
  686. void CondExecMark::init_rt_force_dynamic_mem_alloc_imply_chain() {
  687. if (has_no_shape_infer()) {
  688. return;
  689. }
  690. for (size_t i = 0; i < output().size(); ++i) {
  691. auto s = input(i), t = output(i);
  692. s->add_rt_force_dynamic_mem_alloc_imply_chain(t);
  693. t->add_rt_force_dynamic_mem_alloc_imply_chain(s);
  694. }
  695. }
  696. void CondExecMark::mem_plan_fwd_in2out_readonly() {
  697. if (has_no_shape_infer()) {
  698. return;
  699. }
  700. for (size_t i = 0; i < output().size(); ++i) {
  701. auto s = input(i), t = output(i);
  702. m_mem_fwd_success[i] = t->set_fwd_in2out_readonly(
  703. s, SubTensorSpec::make_from_layout(s->layout()));
  704. }
  705. }
  706. void CondExecMark::add_input_layout_constraint() {
  707. if (has_no_shape_infer()) {
  708. for (auto i : input()) {
  709. // reset_dev_tensor_from_other_var already has such requirement
  710. i->add_layout_constraint_contiguous();
  711. }
  712. }
  713. }
  714. CondExecMark::NodeProp* CondExecMark::do_make_node_prop() const {
  715. auto ret = Super::do_make_node_prop();
  716. ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
  717. for (size_t i = 0; i < input().size() - 1; ++ i) {
  718. ret->add_dep_type_existing_var(input(i),
  719. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  720. }
  721. return ret;
  722. }
  723. cg::OperatorNodeBase* CondExecMark::make_opr(SymbolVar ppv,
  724. const VarNodeArrayView& inputs,
  725. const Param& param,
  726. const OperatorNodeConfig& config) {
  727. return ppv.node()->owner_graph()->insert_opr(
  728. std::make_unique<CondExecMark>(ppv.node(), inputs, param, config));
  729. }
  730. SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input,
  731. const Param& param,
  732. const OperatorNodeConfig& config) {
  733. auto mask =
  734. CondExecPred::GlobalRegistry::get(*maybe_ppv.node()->owner_graph())
  735. ->get_mask_from_var(maybe_ppv.node());
  736. if (mask) {
  737. return make_opr(mask->owner(), {input}, param, config)->output(0);
  738. }
  739. return input;
  740. }
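// Gradient of CondExecMark: the (conditional) output grad is merged back into
// an unconditional grad through CondExecMerge. GradMode::SUM supplies the
// input shape so unselected branches contribute zeros; GradMode::SUM_COND_OUT
// keeps the resulting grad itself conditional.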
  741. MGB_IMPL_OPR_GRAD(CondExecMark) {
  742. if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) {
  743. return nullptr;
  744. }
  745. using GradMode = CondExecMark::Param::GradMode;
  746. using MergeMode = CondExecMerge::Param::Mode;
  747. MergeMode grad_mode;
  748. SymbolVarArray grad_shapes;
  749. switch (opr.param().grad_mode) {
  750. case GradMode::SUM:
  751. grad_mode = MergeMode::SUM;
  752. grad_shapes.emplace_back(SymbolVar{opr.input(wrt_idx)}.symshape());
  753. break;
  754. case GradMode::SUM_COND_OUT:
  755. grad_mode = MergeMode::SUM_COND_OUT;
  756. break;
  757. default:
  758. mgb_throw(MegBrainError, "invalid grad_mode");
  759. }
  760. return CondExecMerge::make_opr({out_grad[wrt_idx]}, grad_shapes,
  761. {1, grad_mode}, OperatorNodeConfig{})
  762. ->output(0);
  763. }
  764. /* ============================= CondExecMerge ============================= */
  765. MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecMerge);
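// Input layout of CondExecMerge: the first nr_branch * nr_output inputs are
// grouped by branch, and all vars of one branch must share the same
// ExecutionMask. In SUM mode, nr_output extra shape vars follow (used to
// zero-fill the outputs when no branch is active); in SUM_COND_OUT mode, one
// extra predicate (the OR of all branch PPVs) is appended as a comp-order
// dependency so the merge itself becomes conditional.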
  766. CondExecMerge::CondExecMerge(const VarNodeArrayView& inputs,
  767. const VarNodeArrayView& out_shapes,
  768. const Param& param,
  769. const OperatorNodeConfig& config)
  770. : Super(inputs[0]->owner_graph(), config, "cond_merge", {}),
  771. m_param{param} {
  772. mgb_throw_if(inputs.size() % param.nr_output, GraphError,
  773. "input size can not divide nr_output: %zu %u", inputs.size(),
  774. param.nr_output);
  775. auto global_registry = CondExecPred::GlobalRegistry::get(*owner_graph());
  776. auto nr_branch = inputs.size() / param.nr_output;
  777. mgb_assert(param.nr_output);
  778. for (size_t i = 0; i < param.nr_output; ++i) {
  779. auto ovar = add_output(ssprintf("out%zu", i));
  780. ovar->dtype(inputs[i]->dtype());
  781. // disable system memory allocation because:
  782. // 1. we can directly forward input storage to output
  783. // 2. dynamic allocator would wait for all inputs to become ready (see
  784. // VarNodeMemManager::DynamicAllocOprInfo::host_wait_input_ready),
  785. // which would cause infinite waiting for unselected inputs.
  786. ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
  787. .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
  788. }
  789. MGB_MARK_USED_VAR(mask2str);
  790. m_branch_masks.resize(nr_branch, nullptr);
  791. for (size_t i = 0; i < nr_branch; ++i) {
  792. ExecutionMask* br_mask = nullptr;
  793. for (size_t j = 0; j < param.nr_output; ++j) {
  794. auto ivar = inputs[i * param.nr_output + j];
  795. auto mask = global_registry->require_mask_from_var(ivar);
  796. mgb_throw_if(
  797. output(j)->dtype() != ivar->dtype(), GraphError,
  798. "CondExecMerge input dtypes mismatch: branch=%zu %s vs %s",
  799. i, output(j)->dtype().name(), ivar->dtype().name());
  800. if (!j) {
  801. br_mask = mask;
  802. } else {
  803. mgb_throw_if(br_mask != mask, GraphError,
  804. "CondExecMerge branch %zu have different masks: "
  805. "%s vs %s",
  806. i, mask2str(br_mask).c_str(),
  807. mask2str(mask).c_str());
  808. }
  809. // this flag is added by ExecutionMask; we require flag because
  810. // output var might forward input var storage
  811. mgb_assert(
  812. ivar->contain_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC));
  813. add_input({ivar});
  814. }
  815. m_branch_masks[i] = br_mask;
  816. }
  817. add_equivalence_component<PODHash<Param>>(&m_param);
  818. // handle extra inputs for special modes
  819. if (param.mode == Mode::SUM) {
  820. mgb_assert(out_shapes.size() == param.nr_output);
  821. for (auto i : out_shapes) {
  822. add_input({i});
  823. }
  824. } else {
  825. mgb_assert(out_shapes.empty(),
  826. "out_shapes should not be given if mode is not SUM");
  827. }
  828. if (param.mode == Mode::SUM_COND_OUT) {
  829. VarNodeArray preds;
  830. preds.reserve(nr_branch);
  831. for (auto i : m_branch_masks) {
  832. preds.emplace_back(proxy_var_from_mask(i));
  833. }
  834. auto cn = mixin_infer_output_comp_node(*this, true);
  835. auto preds_or = CondExecPredLogical::make(
  836. preds, CondExecPredLogical::Mode::OR, cn);
  837. add_input({preds_or.node()});
  838. }
  839. }
  840. cg::OperatorNodeBase* CondExecMerge::make_opr(
  841. const VarNodeArrayView& inputs, const VarNodeArrayView& out_shapes,
  842. const Param& param, const OperatorNodeConfig& config) {
  843. mgb_assert(!inputs.empty());
  844. const VarNodeArrayView* out_shapes_ptr = &out_shapes;
  845. Maybe<VarNodeArrayView> out_shapes_from_inp;
  846. VarNodeArray out_shapes_from_inp_storage;
  847. if (out_shapes.empty() && param.mode == Mode::SUM) {
  848. // find out_shapes from inputs
  849. mgb_assert(inputs.size() % param.nr_output == 0);
  850. size_t nr_branch = inputs.size() / param.nr_output;
  851. auto inp = [&](size_t br, size_t oidx) {
  852. return inputs[br * param.nr_output + oidx];
  853. };
  854. for (size_t oidx = 0; oidx < param.nr_output; ++oidx) {
  855. bool found = false;
  856. for (size_t br = 0; br < nr_branch; ++br) {
  857. auto ivar = inp(br, oidx);
  858. if (cg::is_static_var_shape(ivar)) {
  859. found = true;
  860. out_shapes_from_inp_storage.push_back(
  861. SymbolVar{ivar}.symshape().node());
  862. break;
  863. }
  864. }
  865. mgb_throw_if(!found, GraphError,
  866. "out_shapes is omitted but no input shape is "
  867. "inferrable for output %zu",
  868. oidx);
  869. }
  870. out_shapes_ptr =
  871. &out_shapes_from_inp.emplace(out_shapes_from_inp_storage);
  872. }
  873. return inputs[0]->owner_graph()->insert_opr(std::make_unique<CondExecMerge>(
  874. inputs, *out_shapes_ptr, param, config));
  875. }
  876. void CondExecMerge::init_output_static_infer_desc() {
  877. using namespace cg::static_infer;
  878. auto&& mgr = owner_graph()->static_infer_manager();
  879. auto nr_out = m_param.nr_output;
  880. auto inp = [this, nr_out](size_t branch, size_t oidx) {
  881. return input(branch * nr_out + oidx);
  882. };
  883. static auto select_one_branch = [](size_t nr_branch,
  884. const InpVal& bval) -> size_t {
  885. bool found = false;
  886. size_t ret;
  887. for (size_t i = 0; i < nr_branch; ++i) {
  888. if (bval.val[i].value().ptr<int>()[0]) {
  889. if (!found) {
  890. found = true;
  891. ret = i;
  892. } else {
  893. mgb_throw(GraphError,
  894. "multiple branches are active in EXACT_ONE mode: "
  895. "%zu and %zu",
  896. ret, i);
  897. }
  898. }
  899. }
  900. mgb_throw_if(!found, GraphError,
  901. "no branch is active in EXACT_ONE mode");
  902. return ret;
  903. };
  904. DepVal branch_deps;
  905. auto nr_branch = m_branch_masks.size();
  906. branch_deps.reserve(nr_branch);
  907. for (size_t i = 0; i < nr_branch; ++i) {
  908. branch_deps.push_back(
  909. {proxy_var_from_mask(m_branch_masks[i]), DepType::VALUE});
  910. }
  911. // register shape and value infers for each output
  912. for (size_t oidx = 0; oidx < nr_out; oidx++) {
  913. if (m_param.mode == Mode::EXACT_ONE_SAME_SHAPE ||
  914. m_param.mode == Mode::SUM_COND_OUT) {
  915. // all branches should have the same shape
  916. bool found = false;
  917. // find any inferrable input var
  918. for (size_t i = 0; i < nr_branch; ++i) {
  919. if (cg::is_static_var_shape(inp(i, oidx))) {
  920. mgr.register_shape_infer(
  921. output(oidx),
  922. ShapeInferDesc::make_identity(inp(i, oidx)));
  923. found = true;
  924. break;
  925. }
  926. }
  927. if (!found) {
  928. mgr.register_shape_infer(
  929. output(oidx),
  930. ShapeInferDesc::make_identity(inp(0, oidx)));
  931. }
  932. } else if (m_param.mode == Mode::SUM) {
  933. auto infer_fn = [](TensorShape& dst, const InpVal& inp) {
  934. cg::copy_tensor_value_to_shape(dst, inp.val[0].value());
  935. return true;
  936. };
  937. mgr.register_shape_infer(output(oidx),
  938. {SourceType::DEP,
  939. {{inp(nr_branch, oidx), DepType::VALUE}},
  940. infer_fn});
  941. } else {
  942. // general shape inference for EXACT_ONE mode
  943. auto infer_fn = [this](TensorShape& dest, const InpVal& inp) {
  944. auto nr_branch = m_branch_masks.size();
  945. size_t branch = select_one_branch(nr_branch, inp);
  946. dest = inp.val.at(nr_branch + branch).shape();
  947. return true;
  948. };
  949. ShapeInferDesc desc{SourceType::DEP, branch_deps, infer_fn};
  950. for (size_t i = 0; i < nr_branch; ++i) {
  951. desc.deps.push_back({inp(i, oidx), DepType::SHAPE});
  952. }
  953. mgr.register_shape_infer(output(oidx), desc);
  954. }
  955. // general value inference
  956. ValueInferDesc desc{SourceType::DEP, branch_deps, {}};
  957. for (size_t i = 0; i < nr_branch; ++i) {
  958. desc.deps.push_back({inp(i, oidx), DepType::VALUE});
  959. }
  960. if (is_exact_one()) {
  961. desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) {
  962. auto nr_branch = m_branch_masks.size();
  963. size_t branch = select_one_branch(nr_branch, inp);
  964. dest = inp.val.at(nr_branch + branch).value();
  965. return true;
  966. };
  967. } else {
  968. mgb_assert(m_param.mode == Mode::SUM ||
  969. m_param.mode == Mode::SUM_COND_OUT);
  970. desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) {
  971. auto nr_branch = m_branch_masks.size();
  972. bool found = false, first = true;
  973. auto&& shape = inp.val.at(nr_branch).shape();
  974. for (size_t i = 0; i < nr_branch && !shape.is_empty(); ++i) {
  975. if (!inp.val[i].value().ptr<int>()[0])
  976. continue;
  977. auto&& cur = inp.val.at(nr_branch + i).value();
  978. // add cur value to dest
  979. if (!found) {
  980. found = true;
  981. dest = cur;
  982. } else {
  983. if (first) {
  984. first = false;
  985. DeviceTensorND tmp;
  986. tmp.copy_from(dest);
  987. dest = std::move(tmp);
  988. }
  989. // comp node is cpu default, so it is safe to use a
  990. // temporary megdnn opr here
  991. auto dnn_opr =
  992. intl::create_megdnn_opr<megdnn::Elemwise>(
  993. dest.comp_node());
  994. dnn_opr->param().mode = Elemwise::Mode::ADD;
  995. dnn_opr->exec({dest.as_megdnn(), cur.as_megdnn()},
  996. dest.as_megdnn());
  997. }
  998. }
  999. if (!found) {
  1000. if (dest.storage().raw_storage().use_count() > 1) {
  1001. // likely to be assigned from some input in previous
  1002. // runs; we create a new tensor to avoid modifying input
  1003. // value
  1004. DeviceTensorND tmp{dest.comp_node(), shape,
  1005. dest.dtype()};
  1006. dest = std::move(tmp);
  1007. } else {
  1008. dest.resize(shape);
  1009. }
  1010. fill_zero_dev_tensor(dest);
  1011. }
  1012. return true;
  1013. };
  1014. }
  1015. mgr.register_value_infer(output(oidx), desc);
  1016. }
  1017. }
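// Runtime behavior: the first enabled branch forwards its storage directly to
// the outputs; in the SUM modes every further enabled branch is accumulated
// with an Elemwise ADD, re-allocating the output first if it still aliases
// forwarded input storage. If no branch is enabled (allowed only in plain SUM
// mode), the outputs are zero-filled with the statically inferred shape.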
  1018. void CondExecMerge::scn_do_execute() {
  1019. auto nr_out = m_param.nr_output;
  1020. auto inp = [this, nr_out](size_t branch, size_t oidx) {
  1021. return input(branch * nr_out + oidx);
  1022. };
  1023. auto cn = this->comp_node();
  1024. mgb_assert(cn == output(0)->comp_node());
  1025. bool first = true;
  1026. auto&& forwarded = m_mem_forwarded;
  1027. std::vector<bool> is_shape_empty(nr_out, false);
  1028. for (size_t br = 0; br < m_branch_masks.size(); ++br) {
  1029. if (!m_branch_masks[br]->enabled()) {
  1030. continue;
  1031. }
  1032. if (first) {
  1033. first = false;
  1034. for (size_t oidx = 0; oidx < nr_out; ++oidx) {
  1035. bool succ = output(oidx)->reset_dev_tensor_from_other_var(
  1036. inp(br, oidx));
  1037. if (inp(br, oidx)->shape().is_empty()) {
  1038. is_shape_empty[oidx] = true;
  1039. continue;
  1040. }
  1041. if (!is_exact_one()) {
  1042. if (forwarded.empty()) {
  1043. forwarded.resize(nr_out);
  1044. }
  1045. forwarded[oidx] = succ;
  1046. }
  1047. }
  1048. } else {
  1049. mgb_throw_if(is_exact_one(), GraphError,
  1050. "multiple branches are active in EXACT_ONE mode");
  1051. auto&& dnn_opr = m_exec_dnn_opr;
  1052. if (!dnn_opr || dnn_opr.comp_node() != cn) {
  1053. dnn_opr = intl::create_megdnn_opr<megdnn::Elemwise>(cn);
  1054. dnn_opr->param().mode = Elemwise::Mode::ADD;
  1055. }
  1056. for (size_t oidx = 0; oidx < nr_out; ++oidx) {
  1057. auto ovar = output(oidx);
  1058. auto&& src = inp(br, oidx)->dev_tensor().as_megdnn();
  1059. auto&& dest = ovar->dev_tensor().as_megdnn();
  1060. mgb_assert(src.layout.eq_shape(dest.layout),
  1061. "shape mismatch: %s vs %s in CondExecMerge",
  1062. src.layout.to_string().c_str(),
  1063. dest.layout.to_string().c_str());
  1064. if (is_shape_empty[oidx]) continue;
  1065. if (forwarded[oidx]) {
  1066. ovar->shape_alloc(ovar->shape());
  1067. auto&& own_dest = ovar->dev_tensor().as_megdnn();
  1068. mgb_assert(own_dest.raw_ptr != dest.raw_ptr);
  1069. dnn_opr->exec({dest, src}, own_dest);
  1070. forwarded[oidx] = false;
  1071. } else {
  1072. dnn_opr->exec({dest, src}, dest);
  1073. }
  1074. }
  1075. }
  1076. }
  1077. if (first) {
  1078. mgb_throw_if(is_exact_one(), GraphError,
  1079. "no branch is selected in EXACT_ONE mode");
  1080. mgb_assert(m_param.mode == Param::Mode::SUM);
  1081. auto&& mgr = owner_graph()->static_infer_manager();
  1082. for (auto var : output()) {
  1083. auto&& dv = var->shape_alloc(mgr.infer_shape(var)).dev_tensor();
  1084. fill_zero_dev_tensor(dv);
  1085. }
  1086. } else if (m_param.mode == Param::Mode::SUM) {
  1087. auto&& mgr = owner_graph()->static_infer_manager();
  1088. for (auto var : output()) {
  1089. auto&& shp_infer = mgr.infer_shape(var);
  1090. auto&& shp_got = var->shape();
  1091. mgb_throw_if(!shp_infer.eq_shape(shp_got), GraphError,
  1092. "inferred shape is %s, actual shape is %s",
  1093. shp_infer.to_string().c_str(),
  1094. shp_got.to_string().c_str());
  1095. }
  1096. }
  1097. }
  1098. void CondExecMerge::add_input_layout_constraint() {
  1099. for (auto i : input()) {
  1100. // reset_dev_tensor_from_other_var already has such requirement
  1101. i->add_layout_constraint_contiguous();
  1102. }
  1103. }
  1104. CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
  1105. auto ret = Super::do_make_node_prop();
  1106. using DT = NodeProp::DepType;
  1107. if (m_param.mode == Mode::SUM) {
  1108. SmallVector<DT> inp_dt(input().size(), DT::DEV_VALUE);
  1109. for (size_t i = 0; i < m_param.nr_output; ++i) {
  1110. inp_dt[inp_dt.size() - i - 1] = DT::HOST_VALUE;
  1111. }
  1112. ret->reset_dep_type(input(), inp_dt);
  1113. } else if (m_param.mode == Mode::SUM_COND_OUT) {
  1114. // PPV can not be used as a usual input, so we can modify dep_map
  1115. // directly
  1116. ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
  1117. }
  1118. for (size_t i = 0; i < m_param.nr_output * m_branch_masks.size(); ++ i) {
  1119. ret->add_dep_type_existing_var(input(i),
  1120. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  1121. }
  1122. return ret;
  1123. }
  1124. MGB_IMPL_OPR_GRAD(CondExecMerge) {
  1125. using Mode = CondExecMerge::Param::Mode;
  1126. if (opr.param().mode == Mode::SUM_COND_OUT &&
  1127. wrt_idx == opr.input().size() - 1) {
  1128. return nullptr;
  1129. }
  1130. if (opr.param().mode == Mode::SUM &&
  1131. wrt_idx >= opr.input().size() - opr.output().size()) {
  1132. return InvalidGrad::make(opr, wrt_idx);
  1133. }
  1134. size_t wrt_branch = wrt_idx / opr.param().nr_output,
  1135. wrt_oidx = wrt_idx % opr.param().nr_output;
  1136. auto og = out_grad.at(wrt_oidx);
  1137. if (!og) {
  1138. return nullptr;
  1139. }
  1140. auto ppv = proxy_var_from_mask(opr.branch_mask(wrt_branch));
  1141. if (ppv->comp_node().mem_node() != og->comp_node().mem_node()) {
  1142. ppv = CondExecPredLogical::make({ppv}, CondExecPredLogical::Mode::AND,
  1143. og->comp_node())
  1144. .node();
  1145. }
  1146. CondExecMark::Param gparam;
  1147. if (opr.param().mode == Mode::EXACT_ONE) {
  1148. // only in this mode different branches may have different shapes, so to
  1149. // avoid shape inference failure we simply skip shape inference here;
  1150. // see TestCondExec.MultiShape
  1151. // TODO: remove this if static infer considers execution mask
  1152. gparam.static_infer = CondExecMark::Param::StaticInfer::NONE;
  1153. }
  1154. return CondExecMark::make_opr(ppv, {og}, gparam,
  1155. OperatorNodeConfig{og->comp_node()})
  1156. ->output(0);
  1157. }
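// Called while building the sum of gradient components for a var: components
// that are conditional (vars under an ExecutionMask, or outputs of
// CondExecMerge sums) cannot be added by a plain elemwise sum, so when more
// than one such component or merge opr is present they are removed from the
// list, flattened into a single branch list and re-merged by one CondExecMerge
// in SUM or SUM_COND_OUT mode whose output is appended back to the grad list.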
  1158. void CondExecMerge::modify_grad_sum_list(VarNode* wrt, VarNodeArray& grads) {
  1159. if (!ExecutionMask::have_alive_instance()) {
  1160. return;
  1161. }
  1162. auto global_registry_vec =
  1163. grads.at(0)
  1164. ->owner_graph()
  1165. ->options()
  1166. .user_data.get_user_data<CondExecPred::GlobalRegistry>();
  1167. if (!global_registry_vec.second) {
  1168. // no cond exec related oprs
  1169. return;
  1170. }
  1171. auto global_registry = global_registry_vec.first[0];
  1172. size_t nr_var_remove = 0, nr_merge_opr = 0;
  1173. VarNodeArray merged_branches;
  1174. static constexpr Param::Mode BAD_MODE = static_cast<Param::Mode>(-1);
  1175. Param::Mode merged_mode = BAD_MODE;
  1176. ExecutionMask* part_exec_mask = nullptr;
  1177. bool have_multiple_exec_mask = false;
  1178. auto check_multiple_mask = [&part_exec_mask,
  1179. &have_multiple_exec_mask](ExecutionMask* mask) {
  1180. if (!part_exec_mask) {
  1181. part_exec_mask = mask;
  1182. } else if (part_exec_mask != mask) {
  1183. have_multiple_exec_mask = true;
  1184. }
  1185. };
  1186. // loop in reverse order, and put vars to be merged at end
  1187. for (size_t i = grads.size(); i;) {
  1188. --i;
  1189. auto opr = grads[i]->owner_opr();
  1190. if (opr->same_type<CondExecMerge>()) {
  1191. // merge sum of CondExecMerge by expanding their inputs
  1192. mgb_assert(opr->output().size() == 1,
  1193. "CondExecMerge in grad list has multiple outputs: "
  1194. "name=%s out=%zu",
  1195. opr->cname(), opr->output().size());
  1196. auto cur_mode = opr->cast_final<CondExecMerge>().param().mode;
  1197. mgb_assert(cur_mode == Param::Mode::SUM ||
  1198. cur_mode == Param::Mode::SUM_COND_OUT);
  1199. if (merged_mode != Param::Mode::SUM_COND_OUT) {
  1200. // only allow promoting merge mode to be cond out (if any of the
  1201. // components are conditional)
  1202. merged_mode = cur_mode;
  1203. }
  1204. merged_branches.insert(merged_branches.end(), opr->input().begin(),
  1205. opr->input().end());
  1206. if (cur_mode == Param::Mode::SUM_COND_OUT) {
  1207. // remove the predicate input
  1208. mgb_assert(opr->input().size() == opr->output().size() + 1);
  1209. merged_branches.pop_back();
  1210. check_multiple_mask(
  1211. global_registry->require_mask_from_var(opr->output(0)));
  1212. } else if (cur_mode == Param::Mode::SUM) {
  1213. // remove shape input
  1214. mgb_assert(opr->input().size() >= opr->output().size() * 2);
  1215. merged_branches.resize(merged_branches.size() -
  1216. opr->output().size());
  1217. }
  1218. ++nr_merge_opr;
  1219. ++nr_var_remove;
  1220. std::swap(grads[grads.size() - nr_var_remove], grads[i]);
  1221. } else if (auto mask = global_registry->get_mask_from_var(grads[i])) {
  1222. check_multiple_mask(mask);
  1223. merged_branches.push_back(grads[i]);
  1224. ++nr_var_remove;
  1225. std::swap(grads[grads.size() - nr_var_remove], grads[i]);
  1226. merged_mode = Param::Mode::SUM_COND_OUT;
  1227. }
  1228. }
  1229. if (have_multiple_exec_mask || nr_merge_opr > 1) {
  1230. mgb_assert(merged_mode != BAD_MODE);
  1231. grads.resize(grads.size() - nr_var_remove);
  1232. SymbolVarArray grad_shapes;
  1233. if (merged_mode == Param::Mode::SUM) {
  1234. grad_shapes.emplace_back(SymbolVar{wrt}.symshape());
  1235. }
  1236. grads.push_back(CondExecMerge::make_opr(merged_branches, grad_shapes,
  1237. {1, merged_mode},
  1238. OperatorNodeConfig{})
  1239. ->output(0));
  1240. }
  1241. }
  1242. #endif // MGB_ENABLE_COND_EXEC
  1243. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
