// tensor_manip.cpp

#include "megbrain/opr/tensor_manip.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/opr/utility.h"
#include "megbrain/utils/arith_helper.h"

#include "./internal/megdnn_opr_wrapper.inl"

using namespace mgb;
using namespace opr;
using namespace intl;

/* f{{{ ======================= local utils ======================= */
namespace {
using OptionalAxis = megdnn::param::OptionalAxisV1;

//! check whether shp is GetVarShape(a)
bool check_is_shape_of(SymbolVar shp, SymbolVar a) {
#if MGB_BUILD_SLIM_SERVING
    return false;
#else
    auto op = shp.node()->owner_opr();
    if (op->same_type<GetVarShape>() && op->input().size() == 1 &&
        op->input()[0] == a.node() &&
        op->cast_final<GetVarShape>().param().axis == OptionalAxis::INVALID_AXIS) {
        return true;
    }
    using namespace cg::static_infer;
    auto&& mgr = a.node()->owner_graph()->static_infer_manager();
    if ((mgr.get_infer_type(shp.node()).value & InferType::CONST) &&
        (mgr.get_infer_type(a.node()).shape & InferType::CONST)) {
        auto&& a_shp = mgr.infer_shape(a.node());
        auto&& shp_val = mgr.infer_value(shp.node());
        TensorShape shp_shp;
        cg::copy_tensor_value_to_shape(shp_shp, shp_val);
        return a_shp.eq_shape(shp_shp);
    }
    return false;
#endif
}

#if !MGB_BUILD_SLIM_SERVING
// return x such that shape_of(var) == x
GetVarShape* get_shape_shortcut(VarNode* var) {
    auto opr = var->owner_opr();
    auto otype = opr->dyn_typeinfo();
    if (!(otype == Reshape::typeinfo() &&
          opr->cast_final<Reshape>().param().axis == OptionalAxis::INVALID_AXIS) &&
        otype != Broadcast::typeinfo()) {
        return nullptr;
    }
    auto i1 = opr->input(1)->owner_opr();
    if (i1->same_type<GetVarShape>())
        return &i1->cast_final<GetVarShape>();
    return nullptr;
}
#endif
}  // anonymous namespace
// f}}}
/* f{{{ ======================= GetVarShape ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GetVarShape);
GetVarShape::GetVarShape(
        const VarNodeArrayView& inp, Param axis, const OperatorNodeConfig& config)
        : Super(inp.at(0)->owner_graph(), config, "shape_of", inp), m_axis{axis} {
    m_src_shapes.resize(inp.size());
    for (auto i : inp)
        add_input({i});
    add_input({}, AddInputSortType::ALL);
    add_output(None)->dtype(dtype::Int32());
    add_equivalence_component<PODHash<Param>>(&m_axis);
    mgb_assert(abs(m_axis.axis) <= m_axis.MAX_NDIM);
}
void GetVarShape::update_cached_shape() {
    TensorShape ishp;
    if (m_src_shapes.size() == 1) {
        ishp = m_src_shapes[0];
    } else {
        megdnn::Elemwise::deduce_shape(m_src_shapes, ishp);
    }
    mgb_assert(ishp.ndim);
    // check whether m_cached_shape is valid and update it if not
    if (m_axis.axis != OptionalAxis::INVALID_AXIS) {
        int axis = m_axis.axis;
        if (axis < 0) {
            axis += ishp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int)ishp.ndim);
        if (m_cached_shape.ndim == 1 && m_cached_shape.shape[0] == ishp.shape[axis])
            return;
        m_cached_shape = {ishp.shape[axis]};
    } else {
        if (m_cached_shape.eq_shape(ishp))
            return;
        m_cached_shape = ishp;
    }
    cg::copy_shape_to_tensor_value(m_cached_shape_cpu_v, m_cached_shape);
    m_cached_shape_dev_v_synced = false;
}
void GetVarShape::scn_do_execute() {
    for (size_t i = 0; i < m_src_shapes.size(); ++i) {
        m_src_shapes[i] = input()[i]->shape();
    }
    update_cached_shape();
    if (!m_cached_shape_dev_v_synced) {
        m_cached_shape_dev_v.copy_from(m_cached_shape_cpu_v);
        m_cached_shape_dev_v_synced = true;
    }
    output(0)->dev_tensor().copy_from_fixlayout(m_cached_shape_dev_v);
}
void GetVarShape::update_for_static_infer(const cg::static_infer::InpVal& inp) {
    for (size_t i = 0; i < m_src_shapes.size(); ++i) {
        m_src_shapes[i] = inp.val.at(i).shape();
    }
    update_cached_shape();
}
void GetVarShape::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto infer_shape = [this](TensorShape& dest, const InpVal& inp) {
        update_for_static_infer(inp);
        dest = m_cached_shape_cpu_v.shape();
        return true;
    };
    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
        update_for_static_infer(inp);
        dest = m_cached_shape_cpu_v;
        return true;
    };
    DepVal deps;
    for (auto i : input()) {
        deps.push_back({i, DepType::SHAPE});
    }
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
    mgr.register_value_infer(output(0), {SourceType::DEP, deps, infer_value});
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(GetVarShape) {
    MGB_MARK_USED_VAR(wrt_idx);
    MGB_MARK_USED_VAR(out_grad);
    return nullptr;
}
#endif
SymbolVar GetVarShape::make(
        const VarNodeArrayView& inp, Param param, const OperatorNodeConfig& config) {
    mgb_assert(!inp.empty());
#if !MGB_BUILD_SLIM_SERVING
    // try to apply shortcut and omit scalar shapes to optimize
    VarNodeArray inp_vp;
    inp_vp.reserve(inp.size());
    auto&& mgr = inp[0]->owner_graph()->static_infer_manager();
    for (auto var : inp) {
        auto&& it = mgr.get_infer_type(var);
        if (it.shape & cg::static_infer::InferType::CONST) {
            if (mgr.infer_shape(var).is_scalar()) {
                // scalar does not affect broadcast result
                continue;
            }
        }
        if (auto opr = get_shape_shortcut(var)) {
            // current var replaced by a shortcut
            auto&& op_inp = opr->input();
            inp_vp.insert(inp_vp.end(), op_inp.begin(), op_inp.end());
            continue;
        }
        inp_vp.push_back(var);
    }
    if (inp_vp.empty()) {
        // all inputs are scalar
        mgb_assert(param.axis == OptionalAxis::INVALID_AXIS || param.axis == 0);
        return SymbolVar{inp[0]}.make_scalar(1);
    }
#else
    auto&& inp_vp = inp;
#endif
    return SymbolVar{inp[0]}.insert_single_output_opr<GetVarShape>(
            inp_vp, param, config);
}
cg::OperatorNodeBase::NodeProp* GetVarShape::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    using DT = NodeProp::DepType;
    SmallVector<DT> dt(input().size(), DT::SHAPE);
    prop->reset_dep_type(input(), dt);
    return prop;
}
class GetVarShape::ShapeDevValueExecDep final : public ExecDependency {
    DeviceTensorStorage m_val;

public:
    explicit ShapeDevValueExecDep(DeviceTensorStorage val) : m_val(std::move(val)) {}
};
void GetVarShape::record_execute_deps(ExecDependencyArray& deps) {
    deps.emplace_back(
            std::make_unique<ShapeDevValueExecDep>(m_cached_shape_dev_v.storage()));
}
// f}}}
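/* Usage sketch (editor addition, illustrative only): GetVarShape::make returns an
 * Int32 var holding either the full shape or, when an axis Param is given, the
 * extent of a single axis. `x` is an assumed pre-existing SymbolVar; the call forms
 * mirror uses of GetVarShape::make elsewhere in this file.
 *
 *     auto shape_all = opr::GetVarShape::make(x.node());     // 1-D vector: shape of x
 *     auto extent_1 = opr::GetVarShape::make(x.node(), 1);   // scalar: size of axis 1
 */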
/* f{{{ ======================= ReshapeBrdcastHelper ======================= */
void ReshapeBrdcastHelper::reshapebrdcast_init(VarNode* inp, VarNode* tshp) {
    add_input({inp, tshp});
    add_output(None)->dtype(inp->dtype()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    if (reshapebrdcast_output_shape_need_input_shape())
        outshape_by_symvar_enable(1, 1);
    else
        outshape_by_symvar_enable(0, 1);
}
void ReshapeBrdcastHelper::mem_plan_fwd_in2out_readonly() {
    auto&& tshape = output(0)->shape();
    auto inp_layout = input(0)->layout();
    auto dst_layout = reshapebrdcast_get_dest_layout(inp_layout, tshape);
    if (!dst_layout.valid()) {
        // retry after making input contiguous
        mgb_assert(dyn_typeinfo() == Reshape::typeinfo());
        inp_layout.init_contiguous_stride(input(0)->shape());
        dst_layout = reshapebrdcast_get_dest_layout(inp_layout, tshape);
        mgb_assert(dst_layout.valid());
        m_rofwd_subspec = SubTensorSpec::make_from_layout(dst_layout.val());
        m_incompatible_inp_layout = true;
        return;
    }
    m_rofwd_subspec = SubTensorSpec::make_from_layout(dst_layout.val());
    m_incompatible_inp_layout = false;
    rofwd_init_mem_plan();
}
void ReshapeBrdcastHelper::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    if (reshapebrdcast_output_shape_need_input_shape()) {
        TensorShape oshp_given;
        cg::copy_tensor_value_to_shape(oshp_given, *shpinfo.shpval_inp_val.at(0));
        TensorLayout src;
        src.init_contiguous_stride(shpinfo.shape_inp_shp.at(0));
        dest = reshapebrdcast_get_dest_layout(src, oshp_given).val();
    } else {
        cg::copy_tensor_value_to_shape(dest, *shpinfo.shpval_inp_val.at(0));
    }
}
void ReshapeBrdcastHelper::scn_do_execute() {
    if (m_incompatible_inp_layout) {
        // only happens in reshape
        auto&& iv = input(0)->dev_tensor();
        auto ishp = iv.shape();
        auto&& ov = output(0)->dev_tensor();
        mgb_assert(ishp.total_nr_elems() == ov.shape().total_nr_elems());
        ov.sub(SubTensorSpec::make_from_layout({ishp, iv.dtype()}))
                .copy_from_fixlayout(iv);
    } else
        rofwd_execute();
}
void ReshapeBrdcastHelper::add_input_layout_constraint() {
    if (!cg::is_static_var_value(input(1)))
        return;
    auto check_layout = [this](const TensorLayout& layout) {
        MGB_TRY {
            TensorShape oshp;
            outshape_by_symvar_do_get_output_shape(
                    oshp, outshape_by_symvar_get_shape_infer_info());
            return reshapebrdcast_get_dest_layout(layout, oshp).valid();
        }
        MGB_CATCH(MegBrainError & exc, {
            if (!exc.extra_info())
                cg::OperatorNodeExcExtraInfo::record(this, exc);
            throw;
        })
    };
    input(0)->add_layout_constraint(check_layout);
}
void ReshapeBrdcastHelper::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    using namespace cg::static_infer;
    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
        TensorShape oshp;
        cg::copy_tensor_value_to_shape(oshp, inp.val.at(1).value());
        auto&& iv = inp.val[0].value();
        auto sub_layout = reshapebrdcast_get_dest_layout(iv.layout(), oshp);
        if (sub_layout.valid()) {
            dest = const_cast<DeviceTensorND&>(iv).sub(
                    SubTensorSpec::make_from_layout(sub_layout.val()));
        } else {
            // use contig dest
            dest = {};
            dest.copy_from(iv);
            sub_layout = reshapebrdcast_get_dest_layout(dest.layout(), oshp);
            mgb_assert(sub_layout.valid());
            dest = dest.sub(SubTensorSpec::make_from_layout(sub_layout.val()));
        }
        return true;
    };
    owner_graph()->static_infer_manager().register_value_infer(
            output(0), {SourceType::DEP,
                        {{input(0), DepType::VALUE}, {input(1), DepType::VALUE}},
                        infer_value});
}
ReshapeBrdcastHelper::NodeProp* ReshapeBrdcastHelper::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}
// f}}}
/* f{{{ ======================= Reshape ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Reshape);
Reshape::Reshape(
        VarNode* inp, VarNode* tshp, Param unspec_axis,
        const OperatorNodeConfig& config)
        : Super{inp->owner_graph(), config, "reshape", {inp}},
          m_unspec_axis{unspec_axis} {
    reshapebrdcast_init(inp, tshp);
    add_equivalence_component<PODHash<Param>>(&m_unspec_axis);
}
SymbolVar Reshape::make(
        SymbolVar inp, SymbolVar tshp, Param unspec_axis,
        const OperatorNodeConfig& config) {
    if (check_is_shape_of(tshp, inp))
        return inp;
    return inp.insert_single_output_opr<Reshape>(
            inp.node(), tshp.node(), unspec_axis, config);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Reshape) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
    return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
}
#endif
Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
        const TensorLayout& src, const TensorShape& tshape) const {
    if (m_unspec_axis.axis == OptionalAxis::INVALID_AXIS) {
        TensorLayout ret;
        if (src.try_reshape(ret, tshape))
            return ret;
        return None;
    }
    int original_unspec = m_unspec_axis.axis;
    if (original_unspec < 0) {
        original_unspec += tshape.ndim;
    }
    size_t unspec = original_unspec;
    mgb_assert(unspec < tshape.ndim);
    auto actual_tshape = tshape;
    size_t rem_nr_elem = 1;
    for (size_t i = 0; i < tshape.ndim; ++i) {
        if (i != unspec)
            rem_nr_elem *= tshape.shape[i];
    }
    auto tot_nr_elem = src.total_nr_elems();
    actual_tshape.shape[unspec] = 0;
    mgb_throw_if(
            !rem_nr_elem || tot_nr_elem % rem_nr_elem, TensorReshapeError,
            "could not reshape: src=%s tshape=%s unspec_axis=%zd",
            static_cast<const TensorShape&>(src).to_string().c_str(),
            actual_tshape.to_string().c_str(), unspec);
    actual_tshape.shape[unspec] = tot_nr_elem / rem_nr_elem;
    TensorLayout ret;
    if (src.try_reshape(ret, actual_tshape))
        return ret;
    return None;
}
bool Reshape::reshapebrdcast_output_shape_need_input_shape() const {
    return m_unspec_axis.axis != OptionalAxis::INVALID_AXIS;
}
// f}}}
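/* Usage sketch (editor addition, illustrative only): the unspec_axis Param marks one
 * target axis whose extent is deduced from the total element count, like the "-1"
 * axis of a numpy reshape. `x` and `tshp` are assumed pre-existing SymbolVars, with
 * `tshp` holding the target shape as an Int32 vector.
 *
 *     auto flat = opr::Reshape::make(x, x.make_scalar(6));   // reshape to a 6-element vector
 *     auto y = opr::Reshape::make(x, tshp, {1});             // axis 1 deduced from x's size
 */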
/* f{{{ ======================= Broadcast ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast);
Broadcast::Broadcast(VarNode* inp, VarNode* tshp, const OperatorNodeConfig& config)
        : Super{inp->owner_graph(), config, "broadcast", {inp}} {
    reshapebrdcast_init(inp, tshp);
}
SymbolVar Broadcast::make(
        SymbolVar inp, SymbolVar tshp, const OperatorNodeConfig& config) {
    if (check_is_shape_of(tshp, inp))
        return inp;
    return inp.insert_single_output_opr<Broadcast>(inp.node(), tshp.node(), config);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Broadcast) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
    return Reduce::make(
                   out_grad.at(0), Reduce::Mode::SUM, GetVarShape::make(opr.input(0)))
            .node();
}
#endif
Maybe<TensorLayout> Broadcast::reshapebrdcast_get_dest_layout(
        const TensorLayout& src, const TensorShape& tshape) const {
    return src.broadcast(tshape);
}
bool Broadcast::reshapebrdcast_output_shape_need_input_shape() const {
    return false;
}
// f}}}
/* f{{{ ======================= AxisManipOprBase ======================= */
void AxisManipOprBase::mem_plan_fwd_in2out_readonly() {
    m_rofwd_subspec = SubTensorSpec::make_from_layout(
            axis_manip_get_output_layout(input(0)->layout()));
    rofwd_init_mem_plan();
}
void AxisManipOprBase::scn_do_execute() {
    rofwd_execute();
}
void AxisManipOprBase::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto infer_shape = [this](TensorShape& dest, const InpVal& inp) {
        dest = axis_manip_get_output_layout({inp.val.at(0).shape(), input(0)->dtype()});
        return true;
    };
    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
        auto&& iv = inp.val.at(0).value();
        auto oly = axis_manip_get_output_layout(iv.layout());
        dest = const_cast<DeviceTensorND&>(iv).sub(
                SubTensorSpec::make_from_layout(oly));
        return true;
    };
    mgr.register_shape_infer(
            output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
    mgr.register_value_infer(
            output(0), {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
}
AxisManipOprBase::NodeProp* AxisManipOprBase::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}
void AxisManipOprBase::axis_manip_init(VarNode* inp) {
    add_input({inp});
    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
// f}}}
/* f{{{ ======================= Dimshuffle ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Dimshuffle);
Dimshuffle::Dimshuffle(
        VarNode* inp, const std::vector<int>& pattern, size_t ndim,
        const OperatorNodeConfig& config)
        : Super{inp->owner_graph(), config, "dimshuffle", {inp}},
          m_pattern(pattern),
          m_inp_ndim(ndim) {
    mgb_throw_if(
            m_pattern.size() > TensorShape::MAX_NDIM, GraphError,
            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
    for (auto i : m_pattern) {
        mgb_throw_if(i < -1 || i >= int(ndim), GraphError, "bad Dimshuffle pattern");
    }
    axis_manip_init(inp);
    add_equivalence_component<PODHash<int>>(m_pattern.data(), m_pattern.size());
}
SymbolVar Dimshuffle::make(
        SymbolVar inp, const std::vector<int>& pattern, size_t ndim,
        const OperatorNodeConfig& config) {
    if (!ndim)
        ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
    return inp.insert_single_output_opr<Dimshuffle>(inp.node(), pattern, ndim, config);
}
TensorLayout Dimshuffle::axis_manip_get_output_layout(const TensorLayout& ily) const {
    mgb_assert(
            ily.ndim == m_inp_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", m_inp_ndim,
            ily.ndim);
    TensorLayout oly{ily.dtype};
    oly.ndim = m_pattern.size();
    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {0};
    for (auto i : m_pattern) {
        if (i < 0) {
            oly.shape[idx] = 1;
            oly.stride[idx] = 1;
        } else {
            input_used[i] = true;
            oly.shape[idx] = ily.shape[i];
            oly.stride[idx] = ily.stride[i];
        }
        ++idx;
    }
    for (size_t i = 0; i < m_inp_ndim; ++i) {
        mgb_assert(
                input_used[i] || ily.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                static_cast<const TensorShape&>(ily).to_string().c_str(), i);
    }
    return oly;
}
VarNode* Dimshuffle::grad(size_t /*wrt_idx*/, const VarNodeArray& out_grad) const {
    std::vector<int> back(m_inp_ndim, -1);
    for (size_t i = 0; i < m_pattern.size(); i++) {
        // outdim[i] is indim[j]
        auto j = m_pattern[i];
        if (j >= 0) {
            mgb_assert(
                    back[j] == -1,
                    "taking grad for Dimshuffle with duplicated "
                    "input axis unsupported");
            back[j] = i;
        }
    }
    return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node();
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Dimshuffle) {
    return opr.grad(wrt_idx, out_grad);
}
#endif
// f}}}
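/* Usage sketch (editor addition, illustrative only): pattern entries index input
 * axes and -1 inserts a new axis of extent 1, so for a 2-D input the pattern
 * {1, -1, 0} produces shape (s1, 1, s0). `x` is an assumed pre-existing SymbolVar.
 *
 *     auto y = opr::Dimshuffle::make(x, {1, -1, 0}, 2);
 */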
/* f{{{ ======================= AxisAddRemove ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AxisAddRemove);
AxisAddRemove::AxisAddRemove(
        VarNode* inp, const std::vector<AxisDesc>& desc,
        const OperatorNodeConfig& config)
        : Super{inp->owner_graph(), config, "axis_add_rm", {inp}}, m_desc(desc) {
    mgb_throw_if(desc.empty(), GraphError, "desc for AxisAddRemove could not be empty");
    axis_manip_init(inp);
    add_equivalence_component<PODHash<AxisDesc>>(m_desc.data(), m_desc.size());
}
SymbolVar AxisAddRemove::make(
        SymbolVar inp, const std::vector<AxisDesc>& desc,
        const OperatorNodeConfig& config) {
    return inp.insert_single_output_opr<AxisAddRemove>(inp.node(), desc, config);
}
TensorLayout AxisAddRemove::axis_manip_get_output_layout(
        const TensorLayout& input_layout) const {
    auto layout = input_layout;
    for (auto&& i : m_desc) {
        using M = AxisDesc::Method;
        switch (i.method) {
            case M::REMOVE: {
                auto axis = i.axis.get(layout.ndim);
                if (layout.ndim == 1) {
                    mgb_assert(
                            layout.shape[0] == 1 && axis == 0,
                            "can not remove axis %zu from tensor of shape=%s", axis,
                            layout.megdnn::TensorShape::to_string().c_str());
                } else {
                    mgb_assert(
                            axis < layout.ndim && layout.shape[axis] == 1,
                            "can not remove axis %zu from tensor of shape=%s", axis,
                            layout.megdnn::TensorShape::to_string().c_str());
                    layout.remove_axis_inplace(axis);
                }
                break;
            }
            case M::ADD_1:
                layout.add_axis_cont_inplace(i.axis.get(layout.ndim + 1));
                break;
        }
    }
    return layout;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AxisAddRemove) {
    MGB_MARK_USED_VAR(wrt_idx);
    return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
}
#endif
// f}}}
/* f{{{ ======================= Subtensor ======================= */
Subtensor::Subtensor(
        VarNode* inp, const IndexDesc& desc, const OperatorNodeConfig& config)
        : Super({inp->owner_graph(), config, "subtensor", {inp}}, inp, nullptr, desc,
                true) {
    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
SymbolVar Subtensor::make(
        SymbolVar inp, const IndexDesc& desc, const OperatorNodeConfig& config) {
    return inp.insert_single_output_opr<Subtensor>(inp.node(), desc, config);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Subtensor);
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Subtensor) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
    return IncrSubtensor::make(
                   SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
                   opr.index_desc())
            .node();
}
#endif
void Subtensor::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    DepVal deps;
    // shape inference only needs slices
    deps.push_back({input(0), DepType::SHAPE});
    for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
        if (!m_input2idxonly_axis_indexer[i])
            deps.push_back({input(i), DepType::VALUE});
    }
    auto infer_shape = [this](TensorShape& dest, const InpVal& inp) {
        auto&& ishp = inp.val[0].shape();
        auto subspec =
                fancy_indexing_make_sub_spec({ishp, input(0)->dtype()}, inp, 1, true);
        dest = subspec.layout();
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0), {SourceType::DEP, deps, infer_shape});
    deps.clear();
    for (auto i : input())
        deps.push_back({i, DepType::VALUE});
    deps[0].type = DepType::VALUE;
    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
        auto&& iv = inp.val[0].value();
        auto subspec = fancy_indexing_make_sub_spec(iv.layout(), inp, 1);
        dest = const_cast<DeviceTensorND&>(iv).sub(subspec);
        return true;
    };
    owner_graph()->static_infer_manager().register_value_infer(
            output(0), {SourceType::DEP, deps, infer_value});
}
void Subtensor::scn_do_execute() {
    rofwd_execute();
}
void Subtensor::mem_plan_fwd_in2out_readonly() {
    m_rofwd_subspec = fancy_indexing_make_sub_spec(input(0)->layout());
    rofwd_init_mem_plan();
}
void Subtensor::init_rt_force_dynamic_mem_alloc_imply_chain() {
    auto inp = input(0), out = output(0);
    inp->add_rt_force_dynamic_mem_alloc_imply_chain(out);
    out->add_rt_force_dynamic_mem_alloc_imply_chain(inp);
}
Subtensor::NodeProp* Subtensor::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}
// f}}}
/* f{{{ ================== ModifySubtensorImplHelper ================== */
void ModifySubtensorImplHelper::scn_do_execute() {
    auto mod = fancy_indexing_get_tensors_for_modify_in_scn_do_execute();
    modify(mod.first, mod.second);
}
void ModifySubtensorImplHelper::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    // try to register shape infer with subtensor shape check
    auto try_infer_shape_with_check = [&]() -> bool {
        if (!cg::is_static_var_shape(input(0)) || !cg::is_static_var_shape(input(1)))
            return false;
        for (size_t i = 2; i < input().size(); ++i) {
            if (!cg::is_static_var_value(input(i)) ||
                !mgr.infer_value_fallible(input(i)))
                return false;
        }
        auto infer_shape = [this](TensorShape& dest, const InpVal& inp) {
            dest = inp.val.at(0).shape();
            // throw exception if shapes mismatch
            auto subspec =
                    fancy_indexing_make_sub_spec({dest, input(0)->dtype()}, inp, 2);
            auto&& subshp = inp.val.at(1).shape();
            mgb_throw_if(
                    !subspec.layout().eq_shape(subshp), TensorReshapeError,
                    "SetSubtensor shape mismatch: subspec=%s value_shape=%s",
                    subspec.layout().TensorShape::to_string().c_str(),
                    subshp.to_string().c_str());
            return true;
        };
        DepVal deps;
        for (auto i : input())
            deps.push_back({i, DepType::VALUE});
        deps[0].type = deps[1].type = DepType::SHAPE;
        mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
        return true;
    };
    if (has_input_tensor_replacer()) {
        mgr.register_shape_infer(output(0), ShapeInferDesc::make_const({}));
    } else {
        if (!try_infer_shape_with_check()) {
            auto infer_shape = [](TensorShape& dest, const InpVal& inp) {
                dest = inp.val.at(0).shape();
                return true;
            };
            mgr.register_shape_infer(
                    output(0),
                    {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
        }
    }
    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
        dest.copy_from(inp.val.at(0).value());
        auto subspec = fancy_indexing_make_sub_spec(dest.layout(), inp, 2);
        auto dsub = dest.sub(subspec);
        modify(dsub, inp.val.at(1).value());
        return true;
    };
    DepVal value_deps;
    for (auto i : input())
        value_deps.push_back({i, DepType::VALUE});
    mgr.register_value_infer(output(0), {SourceType::DEP, value_deps, infer_value});
}
// f}}}
/* f{{{ ======================= SetSubtensor ======================= */
SetSubtensor::SetSubtensor(
        VarNode* inp, VarNode* value, const IndexDesc& desc,
        const OperatorNodeConfig& config,
        const InputTensorReplacer& input_tensor_replacer)
        : Super({inp->owner_graph(), config, "set_subtensor", {inp, value}}, inp, value,
                desc, true, input_tensor_replacer) {
    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
SymbolVar SetSubtensor::make(
        SymbolVar inp, SymbolVar value, const IndexDesc& desc,
        const OperatorNodeConfig& config,
        const InputTensorReplacer& input_tensor_replacer) {
    return inp.insert_single_output_opr<SetSubtensor>(
            inp.node(), value.node(), desc, config, input_tensor_replacer);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SetSubtensor);
void SetSubtensor::modify(DeviceTensorND& sub, const DeviceTensorND& val) {
    if (!val.layout().is_empty()) {
        sub.copy_from_fixlayout(val);
    }
}
SetSubtensor::NodeProp* SetSubtensor::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    ret->add_dep_type_existing_var(input(1), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetSubtensor) {
    if (wrt_idx >= 2)
        return InvalidGrad::make(opr, wrt_idx);
    if (wrt_idx == 0) {
        return SetSubtensor::make(
                       out_grad.at(0), SymbolVar{opr.input(1)}.fill_retain_dtype(0),
                       opr.index_desc())
                .node();
    }
    return Subtensor::make(out_grad.at(0), opr.index_desc()).node();
}
#endif
// f}}}
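/* Usage sketch (editor addition, illustrative only): an IndexDesc entry is filled
 * the same way IndexAt::make does further below; here row 0 of `x` is overwritten
 * by `v`. `x` and `v` are assumed pre-existing SymbolVars of compatible shape.
 *
 *     opr::Subtensor::IndexDesc desc;
 *     desc.emplace_back();
 *     desc.back().axis = 0;
 *     desc.back().idx = x.make_scalar(0);
 *     auto y = opr::SetSubtensor::make(x, v, desc);
 */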
/* f{{{ ======================= IncrSubtensor ======================= */
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrSubtensor, "incr_subtensor", true);
void IncrSubtensor::modify(DeviceTensorND& sub, const DeviceTensorND& val) {
    CompNode opr_comp_node;
    if (sub.comp_node().locator().device == CompNode::Locator::DEVICE_CPU_DEFAULT) {
        // for static infer
        opr_comp_node = CompNode::default_cpu();
    } else {
        opr_comp_node = comp_node();
    }
    auto opr = intl::get_megdnn_global_opr<megdnn::AddUpdate>(opr_comp_node);
    opr->exec(sub.as_megdnn(), val.as_megdnn());
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IncrSubtensor) {
    if (wrt_idx >= 2)
        return InvalidGrad::make(opr, wrt_idx);
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return Subtensor::make(out_grad.at(0), opr.index_desc()).node();
}
#endif
// f}}}
/* f{{{ ======================= IndexAt ======================= */
SymbolVar IndexAt::make(
        SymbolVar inp, const std::vector<std::pair<size_t, SymbolVar>>& index,
        const OperatorNodeConfig& config) {
    Subtensor::IndexDesc desc;
    for (auto&& i : index) {
        desc.emplace_back();
        desc.back().axis = i.first;
        desc.back().idx = i.second;
    }
    return Subtensor::make(inp, desc, config);
}
// f}}}
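/* Usage sketch (editor addition, illustrative only): IndexAt is a thin wrapper that
 * turns (axis, index-var) pairs into a Subtensor IndexDesc. `x` is an assumed
 * pre-existing SymbolVar; the result is roughly x[2, :, 0] in numpy notation.
 *
 *     auto y = opr::IndexAt::make(x, {{0, x.make_scalar(2)}, {2, x.make_scalar(0)}});
 */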
/* f{{{ ======================= Split ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Split);
Split::Options Split::Options::make_average(int axis, size_t nr_part) {
    auto cb = [nr_part](size_t size) {
        std::vector<size_t> part(nr_part, size / nr_part);
        for (size_t i = 0, it = size % nr_part; i < it; ++i)
            ++part[i];
        return part;
    };
    return make_callback(axis, nr_part, cb);
}
Split::Options Split::Options::make_partition(
        int axis, const SymbolVarArray& partition) {
    mgb_assert(!partition.empty());
    Options rst;
    rst.method = Method::SPECIFY;
    rst.axis = axis;
    rst.partition = partition;
    return rst;
}
Split::Options Split::Options::make_partition(
        SymbolVar inp, int axis, const std::vector<size_t>& partition) {
    SymbolVarArray sym_partition;
    for (auto i : partition)
        sym_partition.push_back(inp.make_scalar(static_cast<int>(i)));
    return make_partition(axis, sym_partition);
}
Split::Options Split::Options::make_callback(
        int axis, size_t nr_part, callback_t callback) {
    mgb_assert(nr_part);
    Options rst;
    rst.method = Method::CALL_BACK;
    rst.axis = axis;
    rst.callback = callback;
    rst.nr_part = nr_part;
    return rst;
}
SymbolVarArray Split::make(
        SymbolVar inp, Options opt, const OperatorNodeConfig& config) {
    SymbolVarArray ret;
    auto&& output =
            inp.node()
                    ->owner_graph()
                    ->insert_opr(std::make_unique<Split>(inp.node(), opt, config))
                    ->output();
    for (auto i : output) {
        ret.emplace_back(i);
    }
    return ret;
}
Split::Split(VarNode* inp, const Options& opt, const OperatorNodeConfig& config)
        : Super{inp->owner_graph(), config, "split", {inp}}, m_opt(opt) {
    add_input({inp});
    add_equivalence_component<ScalarHash<size_t>>(m_opt.axis);
    if (m_opt.method == Options::Method::SPECIFY) {
        mgb_assert(!m_opt.partition.empty());
        for (auto&& i : m_opt.partition)
            add_input({i.node()});
        outshape_by_symvar_enable(0, 1);
        m_opt.nr_part = m_opt.partition.size();
    } else {
        // disable dedup
        add_equivalence_component<ScalarHash<void*>>(this);
        mgb_assert(m_opt.method == Options::Method::CALL_BACK);
        mgb_assert(m_opt.nr_part);
    }
    for (size_t i = 0; i < m_opt.nr_part; ++i)
        add_output(ssprintf("o%zd", i))
                ->dtype(inp->dtype())
                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    m_output_spec.resize(m_opt.nr_part);
}
void Split::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    using namespace std::placeholders;
    auto&& mgr = owner_graph()->static_infer_manager();
    DepVal shp_deps{{input(0), DepType::SHAPE}};
    if (m_opt.method == Options::Method::SPECIFY) {
        for (size_t i = 1; i < input().size(); ++i)
            shp_deps.push_back({input(i), DepType::VALUE});
    }
    auto infer_value = [this](size_t oidx, DeviceTensorND& dest, const InpVal& inp) {
        auto&& cur_shp = m_output_spec[oidx].shape;
        mgb_assert(cur_shp.eq_shape(inp.val[1].shape()));
        auto axis = m_opt.axis;
        if (axis < 0)
            axis += m_output_spec[0].shape.ndim;
        size_t offset = 0;
        for (size_t i = 0; i < oidx; ++i)
            offset += m_output_spec[i].shape[axis];
        auto&& iv = inp.val[0].value();
        auto subspec = Slice(offset, offset + cur_shp[axis]).apply(iv.layout(), axis);
        dest.copy_from(const_cast<DeviceTensorND&>(iv).sub(subspec));
        return true;
    };
    for (size_t i = 0; i < output().size(); ++i) {
        auto ov = output(i);
        mgr.register_shape_infer(
                ov, {SourceType::DEP, shp_deps,
                     std::bind(&Split::infer_shape, this, i, _1, _2)});
        mgr.register_value_infer(
                ov, {SourceType::DEP,
                     {{input(0), DepType::VALUE}, {ov, DepType::SHAPE}},
                     std::bind(infer_value, i, _1, _2)});
    }
}
bool Split::infer_shape(
        size_t out_idx, TensorShape& dest, const cg::static_infer::InpVal& inp) {
    mgb_assert(inp.run_id > 0, "run id should be a positive number");
    if (inp.run_id != m_output_shape_version) {
        std::vector<size_t> partition;
        auto ishp = inp.val.at(0).shape();
        auto axis = m_opt.axis;
        if (axis < 0)
            axis += ishp.ndim;
        if (m_opt.method == Options::Method::SPECIFY) {
            for (size_t i = 0; i < m_opt.nr_part; ++i) {
                auto&& val = inp.val.at(i + 1).value();
                mgb_assert(val.shape().is_scalar(), "shapes for Split must be scalars");
                size_t cvt;
                static_cast_dtype_safe(&cvt, val.dtype(), val.raw_ptr());
                partition.push_back(cvt);
            }
        } else {
            partition = m_opt.callback(ishp.shape[axis]);
            mgb_assert(
                    partition.size() == m_opt.nr_part,
                    "nr_part=%zu but split callback returned %zu parts", m_opt.nr_part,
                    partition.size());
        }
        size_t size = 0;
        for (size_t i = 0; i < m_opt.nr_part; ++i) {
            auto p = partition[i];
            size += p;
            auto&& cur = m_output_spec[i].shape;
            cur = ishp;
            cur.shape[axis] = p;
        }
        mgb_assert(
                size == ishp.shape[axis],
                "split size sums to %zd, but shape at the axis is %zd", size,
                ishp.shape[axis]);
        m_output_shape_version = inp.run_id;
    }
    dest = m_output_spec.at(out_idx).shape;
    return true;
}
void Split::init_output_comp_node() {
    auto&& conf_node = config().comp_node();
    auto&& cn_opt = owner_graph()->seq_comp_node_optimizer();
    // details of each comp_node specified
    if (conf_node.size() > 1) {
        mgb_assert(
                conf_node.size() == output().size(),
                "number of CompNodes specified in config should equal to number"
                " of output, but got %zd configured CompNodes while there are"
                " %zd output (node_name=%s node_type=%s)",
                conf_node.size(), output().size(), cname(), dyn_typeinfo()->name);
        auto cn0 = input(0)->comp_node();
        for (size_t i = 0; i < output().size(); i++) {
            auto dvar = output(i);
            dvar->comp_node(conf_node[i]);
            if (conf_node[i].mem_node() != cn0.mem_node())
                cn_opt.register_stream_var(
                        dvar, {CompNode::Stream::COPY,
                               cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
        }
        return;
    }
    CompNode cn;
    if (conf_node.size() == 1) {
        cn = conf_node[0];
    } else {
        cn = input(0)->comp_node();
    }
    for (auto i : output())
        i->comp_node(cn);
    if (cn.mem_node() != input(0)->comp_node().mem_node()) {
        for (auto i : output())
            cn_opt.register_stream_var(
                    i, {CompNode::Stream::COPY,
                        cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
    }
}
cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const {
    auto rst = OperatorNodeBase::do_make_node_prop();
    rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    outshape_by_symvar_reset_node_dep_type(rst);
    rst->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return rst;
}
void Split::do_execute(ExecEnv& env) {
    for (size_t idx = 0; idx < output().size(); ++idx) {
        auto out = output(idx);
        if (!owner_graph()->var_receiver_in_current_comp_seq(out).value_needed())
            continue;
        auto runner = [idx, this]() {
            auto&& in = input(0)->dev_tensor();
            auto&& out = output(idx)->dev_tensor();
            auto&& spec = m_output_spec.at(idx);
            if (out.layout().is_empty()) {
                mgb_assert(spec.subspec.layout().is_empty());
                return;
            }
            owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
                    this, out.comp_node());
            if (spec.mem_fwd_success) {
                mgb_assert(out.raw_ptr() == in.raw_ptr() + spec.subspec.offset_byte());
            } else {
                out.comp_node().activate();
                out.copy_from_fixlayout(in.sub(spec.subspec));
            }
            owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
                    this, out.comp_node());
        };
        env.dispatch_on_comp_node(out->comp_node(), runner);
    }
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Split) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
    mgb_assert(out_grad.size() == opr.output().size());
    SymbolVarArray grad;
    for (size_t i = 0; i < out_grad.size(); ++i) {
        auto gval = out_grad[i];
        if (!gval) {
            gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node();
        }
        grad.emplace_back(gval);
    }
    return Concat::make(
                   grad, opr.options().axis,
                   OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
            .node();
}
#endif
void Split::mem_plan_fwd_in2out_readonly() {
    m_readonly_fwd_called = true;
    init_subspec(true);
}
void Split::init_subspec(bool memfwd) {
    auto in = input(0);
    size_t begin = 0, end = 0;
    for (size_t i = 0; i < output().size(); ++i) {
        auto&& spec = m_output_spec[i];
        auto out = output(i);
        auto real_axis = m_opt.axis;
        if (real_axis < 0)
            real_axis += spec.shape.ndim;
        begin = end;
        mgb_assert(out->shape().eq_shape(spec.shape));
        end = begin + spec.shape.shape[real_axis];
        spec.subspec = Slice(begin, end).apply(in->layout(), real_axis);
        if (out->comp_node() == in->comp_node() && memfwd) {
            spec.mem_fwd_success = out->set_fwd_in2out_readonly(in, spec.subspec);
        } else {
            spec.mem_fwd_success = false;
        }
    }
}
void Split::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    // shape infer handled in this class
    MGB_MARK_USED_VAR(dest);
    MGB_MARK_USED_VAR(shpinfo);
    mgb_assert(0);
}
void Split::add_input_layout_constraint() {
    m_readonly_fwd_called = false;
    auto cn = input(0)->comp_node();
    for (auto i : output())
        if (i->comp_node() != cn) {
            input(0)->add_layout_constraint_contiguous();
            return;
        }
}
void Split::on_mem_status_changed() {
    if (!m_readonly_fwd_called) {
        init_subspec(false);
    }
}
cg::OperatorNodeBase::OprEventCallback Split::get_opr_event_callback() {
    return {std::bind(&Split::on_mem_status_changed, this)};
}
void Split::on_output_comp_node_stream_changed() {}
void Split::init_rt_force_dynamic_mem_alloc_imply_chain() {
    auto inp = input(0);
    auto cn0 = inp->comp_node();
    for (auto i : output()) {
        if (i->comp_node() == cn0) {
            i->add_rt_force_dynamic_mem_alloc_imply_chain(inp);
            inp->add_rt_force_dynamic_mem_alloc_imply_chain(i);
        }
    }
}
// f}}}
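/* Usage sketch (editor addition, illustrative only): make_average builds a callback
 * that divides the extent of the split axis as evenly as possible. `x` is an
 * assumed pre-existing SymbolVar.
 *
 *     auto parts = opr::Split::make(x, opr::Split::Options::make_average(0, 3));
 *     // parts.size() == 3; extents along axis 0 differ by at most 1
 */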
/* f{{{ ======================= Concat ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Concat);
Concat::Concat(const VarNodeArrayView& inp, int axis, const OperatorNodeConfig& config)
        : Super{inp[0]->owner_graph(), config, "concat", inp}, m_axis(axis) {
    mgb_assert(!inp.empty());
    for (auto&& i : inp) {
        add_input({i});
    }
    add_equivalence_component<ScalarHash<size_t>>(m_axis);
    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
void Concat::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    mgb_assert(inp_shape.size() == input().size());
    mgb_assert(out_shape.size() == 1);
    auto&& oshp = out_shape[0];
    oshp = inp_shape[0];
    mgb_throw_if(
            m_axis >= static_cast<int>(oshp.ndim) ||
                    m_axis < -static_cast<int>(oshp.ndim),
            GraphError, "concat axis out of bound: input_ndim=%zu axis=%d", oshp.ndim,
            m_axis);
    auto real_axis = m_axis;
    if (real_axis < 0)
        real_axis += oshp.ndim;
    for (size_t i = 1; i < inp_shape.size(); ++i) {
        auto&& tmp = inp_shape[i];
        mgb_throw_if(
                oshp.ndim != tmp.ndim, GraphError,
                "ndim mismatch: shape=%s inp[%zd]=%s", oshp.to_string().c_str(), i,
                tmp.to_string().c_str());
        for (int n = 0; n < static_cast<int>(tmp.ndim); ++n) {
            if (n == real_axis) {
                oshp.shape[n] += tmp.shape[n];
            } else {
                mgb_throw_if(
                        oshp.shape[n] != tmp.shape[n], GraphError,
                        "Concat input shapes mismatch: "
                        "accum_out_shape=%s cur_inp_shape=%s inp_idx=%zu"
                        " axis_concat=%d axis_mismatch=%d",
                        oshp.to_string().c_str(), tmp.to_string().c_str(), i, real_axis,
                        n);
            }
        }
    }
}
SymbolVar Concat::make(
        const VarNodeArrayView& inp, int axis, const OperatorNodeConfig& config) {
    mgb_assert(!inp.empty());
    if (inp.size() == 1)
        return inp[0];
    intl::BatchedDTypePromotion dtp{inp};
    return SymbolVar{inp[0]}.insert_single_output_opr<Concat>(
            dtp.get_vars(), axis, config);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Concat) {
    auto axis = opr.axis();
    mgb_assert(out_grad.size() == 1);
    OperatorNodeConfig::CompNodeArray comp_node;
    SymbolVarArray partition;
    for (auto i : opr.input()) {
        partition.push_back(GetVarShape::make(i, axis));
        comp_node.push_back(i->comp_node());
    }
    auto ret = Split::make(
            out_grad[0], Split::Options::make_partition(axis, partition),
            OperatorNodeConfig().comp_node_arr(comp_node));
    return cg::to_var_node_array(ret);
}
#endif
void Concat::scn_do_execute() {
    auto&& out = output(0)->dev_tensor();
    size_t end = 0;
    for (auto&& input : this->input()) {
        auto&& in = input->dev_tensor();
        auto begin = end;
        auto real_axis = m_axis;
        if (real_axis < 0)
            real_axis += in.shape().ndim;
        end = begin + in.shape().shape[real_axis];
        if (!in.layout().is_empty()) {
            out.sub(Slice(begin, end).apply(out.layout(), real_axis))
                    .copy_from_fixlayout(in);
        }
    }
}
Concat::NodeProp* Concat::do_make_node_prop() const {
    auto rst = Super::do_make_node_prop();
    rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    for (auto i : input()) {
        rst->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return rst;
}
void Concat::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    using namespace cg::static_infer;
    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
        TensorShape oshp = inp.val[0].shape();
        auto real_axis = m_axis;
        if (real_axis < 0)
            real_axis += oshp.ndim;
        for (size_t i = 1; i < input().size(); ++i)
            oshp.shape[real_axis] += inp.val.at(i).shape().shape[real_axis];
        dest.resize(oshp);
        size_t end = 0;
        for (size_t i = 0; i < input().size(); ++i) {
            auto begin = end;
            end = begin + inp.val[i].shape().shape[real_axis];
            dest.sub(Slice(begin, end).apply(dest.layout(), real_axis))
                    .copy_from_fixlayout(inp.val[i].value());
        }
        return true;
    };
    DepVal deps;
    for (auto i : input())
        deps.push_back({i, DepType::VALUE});
    owner_graph()->static_infer_manager().register_value_infer(
            output(0), {SourceType::DEP, deps, infer_value});
}
void Concat::add_input_layout_constraint() {
    auto cn = output(0)->comp_node();
    for (auto i : input()) {
        if (i->comp_node() != cn) {
            i->add_layout_constraint_contiguous();
        }
    }
}
void Concat::init_output_comp_node() {
    Super::init_output_comp_node();
    auto dcn = output(0)->comp_node();
    for (auto i : input()) {
        if (i->comp_node().mem_node() != dcn.mem_node()) {
            owner_graph()->seq_comp_node_optimizer().register_stream_var(
                    output(0), {CompNode::Stream::COPY,
                                cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
            return;
        }
    }
}
// f}}}
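/* Usage sketch (editor addition, illustrative only): concatenation along axis 1;
 * negative axes count from the last dimension, as handled in
 * Concat::get_output_var_shape above. `a` and `b` are assumed pre-existing
 * SymbolVars with matching extents on all axes except the concat axis.
 *
 *     auto y = opr::Concat::make({a, b}, 1);
 */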
/* f{{{ ======================= ParamPackConcat ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat);
ParamPackConcat::ParamPackConcat(
        VarNodeArray& inp, VarNode* table, const std::vector<dt_int32> offsets_val,
        const OperatorNodeConfig& config)
        : Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp),
          m_offsets(offsets_val) {
    CompNode cn = inp[0]->comp_node();
    add_input({inp[0]});
    for (size_t i = 1; i < inp.size(); i++) {
        add_input({inp[i]});
        mgb_assert(
                cn == inp[i]->comp_node(),
                "input var for param pack must in same comp node");
    }
    add_input({table});
    add_output(None);
    cg::add_workspace_output(this);
    m_opr = intl::create_megdnn_opr<megdnn::ParamPackConcat>(cn);
}
void ParamPackConcat::add_input_layout_constraint() {
    for (auto i : input()) {
        i->add_layout_constraint_contiguous();
    }
}
SymbolVar ParamPackConcat::make(
        const SmallVector<SymbolVar>& inp, const SymbolVar& offsets,
        const std::vector<dt_int32> offsets_val, const OperatorNodeConfig& config) {
    VarNodeArray array(inp.size());
    for (size_t i = 0; i < inp.size(); i++) {
        array[i] = inp[i].node();
    }
    return inp.front().insert_single_output_opr<ParamPackConcat>(
            array, offsets.node(), offsets_val, config);
}
void ParamPackConcat::scn_do_execute() {
    mgb_assert(m_opr.comp_node() == comp_node());
    auto&& inputs = input();
    m_inp_ptr.resize(inputs.size() - 1);
    auto ptr = m_inp_ptr.data();
    for (size_t i = 0; i < inputs.size() - 1; i++) {
        ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr();
    }
    auto offsets = inputs.back()->dev_tensor().as_megdnn();
    megdnn::TensorND srcs(
            ptr, megdnn::TensorLayout({inputs.size() - 1}, dtype::Int32()));
    auto&& dst = output(0)->dev_tensor().as_megdnn();
    m_opr->exec(srcs, offsets, dst, get_megdnn_workspace_from_var(output(1)));
}
void ParamPackConcat::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}
void ParamPackConcat::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto infer_out = [this](TensorShape& dest, const InpVal& inp) {
        dest = {static_cast<unsigned int>(m_offsets.back())};
        return true;
    };
    DepVal shp_deps;
    shp_deps.reserve(input().size());
    for (auto&& inp : input()) {
        shp_deps.emplace_back(DepElement{inp, DepType::SHAPE});
    }
    auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
        TensorShapeArray shapes;
        auto vals = inp.val;
        size_t nr_params = vals.size() - 1;
        dest = {m_opr->get_workspace_in_bytes({nr_params}, vals.back().shape(), dest)};
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::DEP, shp_deps, infer_out});
    mgr.register_shape_infer(output(1), {SourceType::DEP, shp_deps, infer_wk});
}
void ParamPackConcat::on_output_comp_node_stream_changed() {
    Super::on_output_comp_node_stream_changed();
    m_opr = intl::create_megdnn_opr<megdnn::ParamPackConcat>(comp_node());
}
// f}}}
/* f{{{ ======================= ParamPackSplit ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit);
ParamPackSplit::ParamPackSplit(
        VarNode* src, const std::vector<dt_int32> offsets, TensorShapeArray& shapes,
        const OperatorNodeConfig& config)
        : Super{src->owner_graph(), config, "ParamPackSplit", {src}},
          m_shapes(shapes),
          m_offsets(offsets) {
    add_input({src});
    for (size_t i = 0; i < shapes.size(); i++) {
        mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!");
        add_output(ssprintf("param_pack_o%zu", i))
                ->dtype(src->dtype())
                .shape(shapes[i]);
    }
}
void ParamPackSplit::add_input_layout_constraint() {
    input(0)->add_layout_constraint_contiguous();
}
SymbolVarArray ParamPackSplit::make(
        const SymbolVar& src, const std::vector<dt_int32> offsets,
        TensorShapeArray shapes, const OperatorNodeConfig& config) {
    auto&& out = src.node()
                         ->owner_graph()
                         ->insert_opr(std::make_unique<ParamPackSplit>(
                                 src.node(), offsets, shapes, config))
                         ->output();
    SymbolVarArray ret;
    ret.resize(out.size());
    for (size_t i = 0; i < ret.size(); ++i) {
        ret[i] = out[i];
    }
    return ret;
}
void ParamPackSplit::init_output_dtype() {
    // already initialized in constructor
}
void ParamPackSplit::init_rt_force_dynamic_mem_alloc_imply_chain() {
    for (size_t i = 0; i < output().size(); ++i) {
        auto s = input(0), t = output(i);
        s->add_rt_force_dynamic_mem_alloc_imply_chain(t);
        t->add_rt_force_dynamic_mem_alloc_imply_chain(s);
    }
}
void ParamPackSplit::mem_plan_fwd_in2out_readonly() {
    mgb_assert(m_offsets.size() == output().size() * 2);
    for (size_t i = 0; i < output().size(); i++) {
        auto layout = output(i)->layout();
        auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i * 2]);
        mgb_assert(output(i)->set_fwd_in2out_readonly(input(0), spec));
    }
}
bool ParamPackSplit::infer_shape(
        size_t index, TensorShape& dest, const cg::static_infer::InpVal& inp) {
    dest = m_shapes[index];
    return true;
}
void ParamPackSplit::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    using namespace std::placeholders;
    auto&& mgr = owner_graph()->static_infer_manager();
    for (size_t i = 0; i < output().size(); i++) {
        auto ov = output(i);
        mgr.register_shape_infer(
                ov, {SourceType::CONSTANT,
                     {},
                     std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)});
    }
}
void ParamPackSplit::scn_do_execute() {
    int inp_size = input(0)->shape().total_nr_elems();
    mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets");
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ParamPackSplit) {
    mgb_assert(out_grad.size() == opr.output().size());
    SmallVector<SymbolVar> grad;
    for (size_t i = 0; i < out_grad.size(); ++i) {
        auto gval = out_grad[i];
        if (!gval) {
            gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node();
        }
        grad.emplace_back(gval);
    }
    auto offsets_val = opr.get_offsets();
    auto cn = opr.input(0)->comp_node();
    if (opr.config().has_comp_node_set()) {
        cn = opr.config().get_single_comp_node();
    }
    HostTensorND hv{cn, TensorShape{offsets_val.size()}, dtype::Int32{}};
    memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int));
    auto offsets = opr::ImmutableTensor::make(*opr.input(0)->owner_graph(), hv);
    return ParamPackConcat::make(
                   grad, offsets, offsets_val,
                   OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
            .node();
}
#endif
// f}}}
/* f{{{ ======================= RelayoutFormat ======================= */
namespace mgb {
namespace opr {
namespace intl {
template <>
struct MegDNNOprInitPostCtor<RelayoutFormat> {
    static void apply(cg::OperatorNodeBase& opr) {
        if (opr.config().output_dtype().valid()) {
            opr.output(0)->dtype(opr.config().output_dtype());
        } else {
            opr.output(0)->dtype(opr.input(0)->dtype());
        }
    }
};
}  // namespace intl
}  // namespace opr
}  // namespace mgb
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RelayoutFormat);
MEGDNN_OPR_INIT1(RelayoutFormat, "relayout_format")
void RelayoutFormat::init_output_format() {
    TensorFormat src_fmt = input(0)->format(), dst_fmt;
    megdnn_opr()->deduce_format(src_fmt, dst_fmt);
    mgb_assert(output().size() == 2);
    output(0)->format(dst_fmt);
    output(1)->format({});  // default format
}
// f}}}
//
/* f{{{ ======================= PaddingForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(PaddingForward);
MEGDNN_OPR_INIT1(PaddingForward, "padding")
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PaddingForward) {
    mgb_assert(opr.input().size() == 1);
    if (wrt_idx == 0) {
        SymbolVar grad = PaddingBackward::make(out_grad[0], opr.input(0), opr.param());
        return grad.node();
    } else
        return InvalidGrad::make(opr, wrt_idx);
}
#endif
// f}}}
/* f{{{ ======================= PaddingBackward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(PaddingBackward);
MEGDNN_OPR_INIT2(PaddingBackward, "padding_backward", 1, false);
// f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}