You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

custom_opnode.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. #include "megbrain/opr/custom_opnode.h"
  2. #if MGB_CUSTOM_OP
  3. namespace mgb {
  4. namespace opr {
// Register CustomOpNode in MegBrain's dynamic type system (RTTI replacement).
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode);
  6. void CustomOpNode::infer_output_comp_node(void) {
  7. SmallVector<CompNode> input_comp_nodes(input_num());
  8. for (size_t i = 0; i < input_num(); ++i) {
  9. input_comp_nodes[i] = input(i)->comp_node();
  10. }
  11. SmallVector<CompNode> output_comp_nodes =
  12. custom::to_builtin<CompNode, custom::Device>(m_op->infer_output_device(
  13. custom::to_custom<CompNode, custom::Device>(input_comp_nodes),
  14. m_param));
  15. for (size_t i = 0; i < output_num(); ++i) {
  16. mgb_assert(
  17. output_comp_nodes[i] == output_comp_nodes[0],
  18. "only single comp node operator is supported");
  19. output(i)->comp_node(output_comp_nodes[i]);
  20. }
  21. m_comp_node = output_comp_nodes[0];
  22. }
  23. void CustomOpNode::infer_output_dtype(void) {
  24. SmallVector<DType> input_dtypes(input_num());
  25. for (size_t i = 0; i < input_num(); ++i) {
  26. input_dtypes[i] = input(i)->dtype();
  27. }
  28. SmallVector<DType> output_dtypes =
  29. custom::to_builtin<megdnn::DType, custom::DType>(m_op->infer_output_dtype(
  30. custom::to_custom<megdnn::DType, custom::DType>(input_dtypes),
  31. m_param));
  32. for (size_t i = 0; i < output_num(); ++i) {
  33. output(i)->dtype(output_dtypes[i]);
  34. }
  35. }
  36. void CustomOpNode::infer_output_format(void) {
  37. SmallVector<TensorFormat> input_formats(input_num());
  38. for (size_t i = 0; i < input_num(); ++i) {
  39. input_formats[i] = input(i)->format();
  40. }
  41. SmallVector<TensorFormat> output_formats =
  42. custom::to_builtin<TensorFormat, custom::Format>(m_op->infer_output_format(
  43. custom::to_custom<TensorFormat, custom::Format>(input_formats),
  44. m_param));
  45. for (size_t i = 0; i < output_num(); ++i) {
  46. output(i)->format(output_formats[i]);
  47. }
  48. }
  49. void CustomOpNode::infer_output_shape(void) {
  50. SmallVector<TensorShape> input_shapes(input_num());
  51. for (size_t i = 0; i < input_num(); ++i) {
  52. input_shapes[i] = input(i)->shape();
  53. }
  54. SmallVector<TensorShape> output_shapes =
  55. custom::to_builtin<TensorShape, custom::Shape>(m_op->infer_output_shape(
  56. custom::to_custom<TensorShape, custom::Shape>(input_shapes),
  57. m_param));
  58. for (size_t i = 0; i < output_num(); ++i) {
  59. output(i)->shape(output_shapes[i]);
  60. }
  61. }
  62. void CustomOpNode::infer_output_shape(
  63. const TensorShapeArray& input_shapes, TensorShapeArray& output_shapes) {
  64. output_shapes =
  65. custom::to_builtin<TensorShape, custom::Shape>(m_op->infer_output_shape(
  66. custom::to_custom<TensorShape, custom::Shape>(input_shapes),
  67. m_param));
  68. }
  69. // called by computing_graph for each output varnode
  70. bool CustomOpNode::infer_desc(
  71. size_t out_idx, TensorShape& output_shape,
  72. const StaticInferInpVal& input_vals) {
  73. TensorShapeArray input_shapes(input_vals.val.size());
  74. TensorShapeArray output_shapes(output_num());
  75. for (size_t i = 0; i < input_shapes.size(); ++i) {
  76. input_shapes[i] = input_vals.val[i].shape();
  77. }
  78. infer_output_shape(input_shapes, output_shapes);
  79. output_shape = output_shapes.at(out_idx);
  80. return true;
  81. }
// OperatorNodeBase hook: delegate output dtype deduction to the custom op.
void CustomOpNode::init_output_dtype() {
    infer_output_dtype();
}

// OperatorNodeBase hook: delegate output format deduction to the custom op.
void CustomOpNode::init_output_format() {
    infer_output_format();
}

// OperatorNodeBase hook: delegate output comp node deduction to the custom op.
void CustomOpNode::init_output_comp_node() {
    infer_output_comp_node();
}
  91. void CustomOpNode::do_execute(ExecEnv& env) {
  92. auto runner = [this]() {
  93. this->owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
  94. this, m_comp_node);
  95. m_comp_node.activate();
  96. SmallVector<DeviceTensorND> inputs, outputs;
  97. for (size_t i = 0; i < input_num(); i++)
  98. inputs.push_back(input(i)->dev_tensor());
  99. for (size_t i = 0; i < output_num(); i++)
  100. outputs.push_back(output(i)->dev_tensor());
  101. std::vector<custom::Tensor> custom_inputs =
  102. custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
  103. std::vector<custom::Tensor> custom_outputs =
  104. custom::to_custom<DeviceTensorND, custom::Tensor>(outputs);
  105. m_op->compute(custom_inputs, m_param, custom_outputs);
  106. // [TODO] sync should be modified
  107. CompNode::sync_all();
  108. this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
  109. this, m_comp_node);
  110. };
  111. env.dispatch_on_comp_node(m_comp_node, runner);
  112. }
// Register a static shape-inference rule for every output var: each output's
// shape is recomputed by infer_desc() whenever the dependency set changes.
void CustomOpNode::init_output_static_infer_desc() {
    using namespace std::placeholders;
    using namespace cg::static_infer;
    m_out_shape.resize(output_num());
    auto&& mgr = owner_graph()->static_infer_manager();

    DepVal dep;
    // [TODO] need design a interface to allow user to decide it
    // Deliberate placeholder: today every output depends on all input SHAPEs;
    // the dead VALUE branch is kept so the user-selectable dependency type can
    // be wired in later.
    if (true) {
        for (auto input_var : input())
            dep.push_back({input_var, DepType::SHAPE});
    } else {
        for (auto input_var : input())
            dep.push_back({input_var, DepType::VALUE});
    }

    for (size_t i = 0; i < output_num(); ++i) {
        // with no inputs the shape is a constant; otherwise it is re-inferred
        // from `dep` via infer_desc (bound to this output's index)
        mgr.register_shape_infer(
                output(i), {dep.empty() ? SourceType::CONSTANT : SourceType::DEP, dep,
                            std::bind(&CustomOpNode::infer_desc, this, i, _1, _2)});
    }
}
  133. void CustomOpNode::init_output_mem_plan(bool dynamic) {
  134. for (auto output_var : output()) {
  135. if (cg::is_static_var_storage(output_var) == !dynamic &&
  136. !output_var->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC))
  137. output_var->init_mem_plan();
  138. }
  139. }
// No dynamic-alloc implication chain is needed for custom ops.
void CustomOpNode::init_rt_force_dynamic_mem_alloc_imply_chain() {}
  141. void CustomOpNode::add_input_layout_constraint() {
  142. for (auto&& input_var : input()) {
  143. input_var->add_layout_constraint_contiguous();
  144. }
  145. }
// No read-only input->output memory forwarding is supported for custom ops.
void CustomOpNode::mem_plan_fwd_in2out_readonly() {}

// No writable input->output memory forwarding is supported for custom ops.
void CustomOpNode::mem_plan_fwd_in2out_writable() {}

// No extra per-operator event callbacks; return an empty callback set.
cg::OperatorNodeBase::OprEventCallback CustomOpNode::get_opr_event_callback() {
    return {};
}
  151. void CustomOpNode::on_output_comp_node_stream_changed() {
  152. for (auto output_var : output()) {
  153. if (output_var->comp_node() != m_comp_node) {
  154. mgb_assert(output_var->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
  155. output_var->comp_node(m_comp_node);
  156. }
  157. }
  158. }
// Use the default node properties; custom ops add no extra dep/flag handling.
cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const {
    return OperatorNodeBase::do_make_node_prop();
}
  162. bool CustomOpNode::update_priority() const {
  163. if (output_num() == 1 &&
  164. output()[0]->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE)) {
  165. node_prop().attribute().priority =
  166. std::numeric_limits<decltype(NodeProp::Attribute::priority)>::min();
  167. return true;
  168. }
  169. return false;
  170. }
// Construct the operator node: wires the given input vars, creates one output
// var per op output, and registers the (runtime id, param) pair as an
// equivalence component so graph deduplication only merges nodes whose op and
// param are both identical.
CustomOpNode::CustomOpNode(
        const std::shared_ptr<const custom::CustomOp>& op, VarNodeArray inputs,
        const custom::Param& param, const OperatorNodeConfig& config)
        : OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs),
          m_op(op),
          m_param(param) {
    // input_num() comes from the op definition; it must match what the caller
    // actually supplied
    mgb_assert(input_num() == inputs.size(), "wrong input tensors list length");
    for (size_t i = 0; i < input_num(); ++i)
        add_input({inputs[i]});
    for (size_t i = 0; i < output_num(); ++i)
        add_output(output_info(i).name());
    if (!std::is_empty<custom::Param>::value) {
        using step = unsigned long;
        size_t STEP_SIZE = sizeof(step);
        // serialize runtime id plus every (key, value) param entry into one
        // string, then hash it word by word
        std::string hash_str = std::to_string(op->runtime_id());
        for (auto&& val : param.raw()) {
            hash_str += val.first;
            hash_str += val.second.str();
        }
        // pad with spaces so the string splits into whole step-sized words
        if (hash_str.size() % STEP_SIZE != 0)
            hash_str += std::string(STEP_SIZE - (hash_str.size() % STEP_SIZE), ' ');
        // NOTE(review): reads the buffer as step-sized words via
        // reinterpret_cast -- presumably PODHash copies the pointed-to word;
        // confirm it tolerates unaligned access on strict-alignment targets.
        // Also `unsigned long` is 4 bytes on LLP64 platforms, so the word
        // split differs across platforms (appears to be used in-process only).
        for (size_t pos = 0; pos < hash_str.size(); pos += STEP_SIZE)
            add_equivalence_component<PODHash<step>>(
                    reinterpret_cast<const step*>(hash_str.c_str() + pos));
    }
}
  197. VarNodeArray CustomOpNode::make(
  198. const std::shared_ptr<const custom::CustomOp>& op, VarNodeArray inputs,
  199. const custom::Param& param, const OperatorNodeConfig& config) {
  200. auto&& outputs = inputs[0]
  201. ->owner_graph()
  202. ->insert_opr(std::make_unique<CustomOpNode>(
  203. op, inputs, param, config))
  204. ->output();
  205. return outputs;
  206. }
  207. SymbolVarArray CustomOpNode::make(
  208. const std::shared_ptr<const custom::CustomOp>& op, SymbolVarArray inputs,
  209. const custom::Param& param, const OperatorNodeConfig& config) {
  210. VarNodeArray input_vars(inputs.size());
  211. for (size_t i = 0; i < input_vars.size(); ++i)
  212. input_vars[i] = inputs[i].node();
  213. auto&& outputs = inputs[0]
  214. .node()
  215. ->owner_graph()
  216. ->insert_opr(std::make_unique<CustomOpNode>(
  217. op, input_vars, param, config))
  218. ->output();
  219. SymbolVarArray ret(outputs.size());
  220. for (size_t i = 0; i < ret.size(); ++i)
  221. ret[i] = outputs[i];
  222. return ret;
  223. }
// Unique runtime identifier of the wrapped custom op.
custom::RunTimeId CustomOpNode::runtime_id() const {
    return m_op->runtime_id();
}

// Tag of the op's parameter schema, used to identify the param layout.
uint32_t CustomOpNode::param_tag(void) const {
    return m_op->param_info().tag();
}

// Mutable access to the stored parameter set.
custom::Param& CustomOpNode::param(void) {
    return m_param;
}

// Copy of the stored parameter set for const contexts.
custom::Param CustomOpNode::param(void) const {
    return m_param;
}
// a series of functions with the same names as CustomOpImpl
// They forward directly to the wrapped custom::CustomOp.
std::string CustomOpNode::op_type(void) const {
    return m_op->op_type();
}

std::string CustomOpNode::op_desc(void) const {
    return m_op->op_desc();
}

size_t CustomOpNode::input_num(void) const {
    return m_op->input_num();
}

size_t CustomOpNode::output_num(void) const {
    return m_op->output_num();
}

custom::ArgInfo CustomOpNode::input_info(size_t idx) const {
    return m_op->input_info(idx);
}

custom::ArgInfo CustomOpNode::output_info(size_t idx) const {
    return m_op->output_info(idx);
}
  255. } // namespace opr
  256. } // namespace mgb
  257. #endif