You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

forward_sereg.cpp 10 kB


#include "./forward_sereg.h"
#include "./impl.h"
#include "megbrain/opr/internal/param_tag_defs.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"

#include <utility>
#include <vector>
  7. using namespace mgb;
  8. using namespace mgb::opr::intl;
  9. using namespace mgb::serialization;
  10. namespace {
  11. class LoopDumpContext : public UserDataContainer::UserData {
  12. MGB_TYPEINFO_OBJ_DECL;
  13. public:
  14. ThinHashMap<VarNode*, size_t> ogvar2inpidx;
  15. static LoopDumpContext& from_dump_ctx(OprDumpContext& ctx) {
  16. auto ret = ctx.config().user_data->get_user_data<LoopDumpContext>();
  17. mgb_assert(ret.second);
  18. return *ret.first[ret.second - 1];
  19. }
  20. };
  21. class LoopLoadContext : public UserDataContainer::UserData {
  22. MGB_TYPEINFO_OBJ_DECL;
  23. public:
  24. const VarNodeArray& input_vars;
  25. opr::Loop::Desc& desc;
  26. LoopLoadContext(const VarNodeArray& input_vars_, opr::Loop::Desc& desc_)
  27. : input_vars{input_vars_}, desc{desc_} {}
  28. static LoopLoadContext& from_load_ctx(OprLoadContext& ctx) {
  29. auto ret = ctx.config().user_data->get_user_data<LoopLoadContext>();
  30. mgb_assert(ret.second);
  31. return *ret.first[ret.second - 1];
  32. }
  33. };
  34. MGB_TYPEINFO_OBJ_IMPL(LoopDumpContext);
  35. MGB_TYPEINFO_OBJ_IMPL(LoopLoadContext);
  36. } // anonymous namespace
  37. namespace mgb {
  38. namespace opr {
  39. namespace intl {
  40. //! use LoopSerializer because it is friend of LoopImpl
  41. class LoopSerializer {
  42. using InputMaker = LoopImpl::InputMaker;
  43. using CounterProvider = LoopImpl::DescImplBase::CounterProvider;
  44. struct LoopParam {
  45. static constexpr uint32_t TAG = opr::param_tag::LOOP;
  46. Loop::Param opr_param;
  47. uint64_t cond_var_id;
  48. };
  49. struct InputMakerParam {
  50. static constexpr uint32_t TAG = opr::param_tag::LOOP_INPUT_MAKER;
  51. bool has_assign;
  52. uint64_t ogvar_id; //! id of proxied var in owner graph
  53. };
  54. struct OutputListEntry {
  55. uint64_t subvar_id;
  56. LoopImpl::Desc::OutputMode mode;
  57. } MGB_PACKED;
  58. struct AssignListEntry {
  59. uint64_t dst_id, src_id;
  60. };
  61. static void dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr);
  62. static void dump_input_maker(OprDumpContext& ctx, const cg::OperatorNodeBase& opr);
  63. static void dump_counter_provider(
  64. OprDumpContext& ctx, const cg::OperatorNodeBase& opr);
  65. static cg::OperatorNodeBase* load_loop(
  66. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  67. const OperatorNodeConfig& config);
  68. static cg::OperatorNodeBase* load_input_maker(
  69. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  70. const OperatorNodeConfig& config);
  71. static cg::OperatorNodeBase* load_counter_provider(
  72. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  73. const OperatorNodeConfig& config);
  74. public:
  75. static void reg_all();
  76. // we need dedicated shallow_copy because some oprs can be copied
  77. // but can not be dumped; also record InterGraphVarTransformer
  78. static cg::OperatorNodeBase* shallow_copy(
  79. const OprShallowCopyContext& orig_ctx, const Loop& opr,
  80. const VarNodeArray& inputs, const OperatorNodeConfig& config);
  81. };
  82. } // namespace intl
  83. } // namespace opr
  84. } // namespace mgb
  85. namespace mgb {
  86. namespace serialization {
  87. namespace fbs {
  88. template <>
  89. struct SupportFlatBuffersSerialization<opr::intl::LoopSerializer::LoopParam> : No {};
  90. template <>
  91. struct SupportFlatBuffersSerialization<opr::intl::LoopSerializer::InputMakerParam>
  92. : No {};
  93. } // namespace fbs
  94. } // namespace serialization
  95. } // namespace mgb
  96. cg::OperatorNodeBase* serialization::opr_shallow_copy_loop(
  97. const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
  98. const VarNodeArray& inputs, const OperatorNodeConfig& config) {
  99. return opr::intl::LoopSerializer::shallow_copy(
  100. ctx, opr.cast_final_safe<opr::Loop>(), inputs, config);
  101. }
  102. void LoopSerializer::reg_all() {
  103. MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop, true);
  104. MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker, true);
  105. MGB_SEREG_OPR_INTL_CALL_ADD(
  106. CounterProvider, dump_counter_provider, load_counter_provider, true);
  107. MGB_SEREG_OPR_INTL_CALL_ADD_V2(
  108. opr::Loop, dump_loop, load_loop, nullptr, 2, CURRENT_VERSION);
  109. MGB_SEREG_OPR_INTL_CALL_ADD_V2(
  110. InputMaker, dump_input_maker, load_input_maker, nullptr, 2,
  111. CURRENT_VERSION);
  112. MGB_SEREG_OPR_INTL_CALL_ADD_V2(
  113. CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2,
  114. CURRENT_VERSION);
  115. }
  116. void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
  117. bool dump_implemented = false;
  118. mgb_throw_if(
  119. !dump_implemented, SerializationError,
  120. "Serialization of Loop opr not implemented");
  121. }
  122. void LoopSerializer::dump_input_maker(
  123. OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
  124. auto&& ogvar2inpidx = LoopDumpContext::from_dump_ctx(ctx).ogvar2inpidx;
  125. auto&& opr_im = opr.cast_final_safe<InputMaker>();
  126. ctx.write_param<InputMakerParam>(
  127. {opr_im.param().has_assign, ogvar2inpidx.at(opr_im.orig_var())});
  128. }
  129. void LoopSerializer::dump_counter_provider(
  130. OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
  131. // there is nothing needs to do
  132. MGB_MARK_USED_VAR(ctx);
  133. MGB_MARK_USED_VAR(opr);
  134. }
  135. cg::OperatorNodeBase* LoopSerializer::load_loop(
  136. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  137. const OperatorNodeConfig& config) {
  138. bool load_implemented = false;
  139. cg::OperatorNodeBase* load_result = nullptr;
  140. mgb_throw_if(
  141. !load_implemented, SerializationError,
  142. "Serialization of Loop opr not implemented");
  143. return load_result;
  144. }
  145. cg::OperatorNodeBase* LoopSerializer::load_input_maker(
  146. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  147. const OperatorNodeConfig& config) {
  148. MGB_MARK_USED_VAR(config);
  149. auto&& loop_load_ctx = LoopLoadContext::from_load_ctx(ctx);
  150. auto param = ctx.read_param<InputMakerParam>();
  151. return loop_load_ctx.desc
  152. .add_input(loop_load_ctx.input_vars.at(param.ogvar_id), param.has_assign)
  153. .node()
  154. ->owner_opr();
  155. }
  156. cg::OperatorNodeBase* LoopSerializer::load_counter_provider(
  157. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  158. const OperatorNodeConfig& config) {
  159. MGB_MARK_USED_VAR(inputs);
  160. mgb_assert(inputs.empty());
  161. auto&& loop_load_ctx = LoopLoadContext::from_load_ctx(ctx);
  162. return loop_load_ctx.desc.get_counter_var().node()->owner_opr();
  163. }
  164. cg::OperatorNodeBase* LoopSerializer::shallow_copy(
  165. const OprShallowCopyContext& orig_ctx, const Loop& opr,
  166. const VarNodeArray& inputs, const OperatorNodeConfig& config) {
  167. auto orig_desc = static_cast<LoopImpl::FwdDesc*>(opr.m_desc.get());
  168. ThinHashMap<VarNode*, size_t> ogvar2inpidx;
  169. mgb_assert(inputs.size() == opr.input().size());
  170. for (size_t i = 0; i < inputs.size(); ++i)
  171. ogvar2inpidx[opr.input(i)] = i;
  172. VarNodeArray cur_opr_inputs;
  173. auto varmap_buf = std::make_shared<ThinHashMap<VarNode*, VarNode*>>();
  174. auto desc_maker = [&](Loop::Desc& desc) {
  175. ThinHashMap<VarNode*, LoopImpl::InputMaker*> assignee2orig_im;
  176. auto&& varmap = *varmap_buf;
  177. // add inputs
  178. OprShallowCopyContext ctx{orig_ctx};
  179. for (auto inp : orig_desc->all_inputs()) {
  180. auto ogvar = inputs.at(ogvar2inpidx.at(inp->orig_var()));
  181. auto subvar = desc.add_input(ogvar, inp->param().has_assign);
  182. varmap[inp->output(0)] = subvar.node();
  183. if (inp->param().has_assign) {
  184. assignee2orig_im[subvar.node()] = inp;
  185. }
  186. ctx.owner_graph(subvar.node()->owner_graph());
  187. }
  188. // copy oprs
  189. for (auto opr : orig_desc->sub_graph_oprs()) {
  190. if (opr->same_type<LoopImpl::InputMaker>()) {
  191. continue;
  192. }
  193. if (opr->same_type<LoopImpl::DescImplBase::CounterProvider>()) {
  194. varmap[opr->output(0)] = desc.get_counter_var().node();
  195. } else {
  196. cur_opr_inputs.clear();
  197. for (auto i : opr->input())
  198. cur_opr_inputs.push_back(varmap.at(i));
  199. auto new_opr =
  200. copy_opr_shallow(*opr, cur_opr_inputs, opr->config(), ctx);
  201. mgb_assert(new_opr->output().size() == opr->output().size());
  202. for (size_t i = 0; i < new_opr->output().size(); ++i)
  203. varmap[opr->output(i)] = new_opr->output(i);
  204. }
  205. }
  206. // add outputs in original order
  207. for (auto&& i : orig_desc->output_record_spec_no_dedup()) {
  208. desc.add_output(varmap.at(i->var_sub()), i->output_mode());
  209. }
  210. // add assignments
  211. for (auto&& i : assignee2orig_im) {
  212. desc.assign(i.first, varmap.at(i.second->assignor()));
  213. }
  214. desc.set_loop_condition(varmap.at(orig_desc->loop_cond_manager().var().node()));
  215. };
  216. auto&& ret =
  217. opr::Loop::make(desc_maker)[0].node()->owner_opr()->cast_final_safe<Loop>();
  218. mgb_assert(ret.output().size() == opr.output().size());
  219. auto trans_src_var = [varmap_buf](VarNode* src) -> VarNode* {
  220. auto iter = varmap_buf->find(src);
  221. mgb_throw_if(
  222. iter == varmap_buf->end(), GraphError,
  223. "loop fwd shallow copy: "
  224. "can not to get copied var from unused src var: %s",
  225. cg::dump_var_info({src}).c_str());
  226. return iter->second;
  227. };
  228. cg::InterGraphVarTransformer::register_to(
  229. ret.m_desc->sub_graph(), opr.m_desc->sub_graph(), trans_src_var);
  230. return &ret;
  231. }
  232. void LoopSerializerReg::entry() {
  233. LoopSerializer::reg_all();
  234. }
  235. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}