
extern_c_opr.cpp

#include "megbrain/serialization/extern_c_opr.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/extern_copr_api.h"
#include "megbrain/serialization/extern_c_opr_io.h"
#include "megbrain/serialization/opr_load_dump.h"

#include <cstdlib>

using namespace mgb;
using namespace serialization;
using namespace opr;

namespace {

const char PLACEHOLDER_TYPE_NAME[] = "placeholder";

typedef MGBOprDesc* (*opr_desc_transformer_t)(void* input);

using LoaderMap = std::unordered_map<
        std::string, std::pair<MGBOprLoader, opr_desc_transformer_t>>;

//! singleton LoaderMap
LoaderMap& loader_map() {
    static LoaderMap ret;
    return ret;
}

class MGBOprDescHash final : public HashableVD {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
    MGBOprDesc* const m_desc;

    bool is_same_st(const Hashable& rhs) const override {
        return m_desc->is_same(m_desc, static_cast<const MGBOprDescHash&>(rhs).m_desc);
    }

public:
    MGBOprDescHash(MGBOprDesc* desc) : m_desc{desc} {}

    size_t hash() const override { return m_desc->hash(m_desc); }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MGBOprDescHash);
MGBDType dtype_cpp2c(DType dtype) {
    switch (dtype.enumv()) {
        case DTypeEnum::Float32:
            return MGB_DTYPE_FLOAT32;
        case DTypeEnum::Int32:
            return MGB_DTYPE_INT32;
        case DTypeEnum::Int16:
            return MGB_DTYPE_INT16;
        case DTypeEnum::Uint8:
            return MGB_DTYPE_UINT8;
#if !MEGDNN_DISABLE_FLOAT16
        case DTypeEnum::Float16:
            return MGB_DTYPE_FLOAT16;
#endif
        default:
            mgb_throw(
                    InternalError, "unsupported dtype for extern C API: %s",
                    dtype.name());
    }
}

DType dtype_c2cpp(MGBDType dtype) {
    switch (dtype) {
        case MGB_DTYPE_UINT8:
            return dtype::Uint8{};
        case MGB_DTYPE_INT16:
            return dtype::Int16{};
        case MGB_DTYPE_INT32:
            return dtype::Int32{};
        case MGB_DTYPE_FLOAT32:
            return dtype::Float32{};
#if !MEGDNN_DISABLE_FLOAT16
        case MGB_DTYPE_FLOAT16:
            return dtype::Float16{};
#endif
        default:
            mgb_throw(
                    SerializationError, "bad dtype value: %d", static_cast<int>(dtype));
    }
}
template <typename S>
MGBTensor tensor_to_c(const TensorND<S>& src) {
    MGBTensor ret;
    ret.data = const_cast<void*>(static_cast<const void*>(src.raw_ptr()));
    ret.layout.dtype = dtype_cpp2c(src.dtype());
    ret.layout.shape = ExternCOprRunner::tensor_shape_to_c(src.shape());
    return ret;
}

struct MGBOprDescV23 {
    size_t nr_input, nr_output;

    //! operator type name
    const char* type_name;

    //! release this descriptor
    void (*release)(MGBOprDescV23* self);

    //! compute hash
    size_t (*hash)(const MGBOprDescV23* self);

    //! equality check
    int (*is_same)(const MGBOprDescV23* self, const MGBOprDescV23* rhs);

    //! perform the computation
    void (*execute)(
            const MGBOprDescV23* self, const MGBTensor* input, const MGBTensor* output);

    //! infer output shapes from input shapes
    void (*infer_shape)(
            const MGBOprDescV23* self, const MGBTensorShape* input,
            MGBTensorShape* output);

    //! custom user data to be associated with this descriptor
    void* user_data;
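
    //! adapt a legacy v0x23 descriptor to the current MGBOprDesc ABI: the
    //! original MGBOprDescV23 is stashed in user_data of the wrapper, and each
    //! callback below simply forwards to the corresponding v23 callback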
    static MGBOprDesc* as_opr_desc(void* v23_raw) {
        auto release = [](MGBOprDesc* self) {
            auto p = static_cast<MGBOprDescV23*>(self->user_data);
            p->release(p);
            delete self;
        };
        auto hash = [](const MGBOprDesc* self) {
            auto p = static_cast<MGBOprDescV23*>(self->user_data);
            return p->hash(p);
        };
        auto is_same = [](const MGBOprDesc* self, const MGBOprDesc* rhs) {
            auto p0 = static_cast<MGBOprDescV23*>(self->user_data);
            auto p1 = static_cast<MGBOprDescV23*>(rhs->user_data);
            return p0->is_same(p0, p1);
        };
        auto execute = [](const MGBOprDesc* self, const MGBTensor* input,
                          const MGBTensor* output) {
            auto p = static_cast<MGBOprDescV23*>(self->user_data);
            p->execute(p, input, output);
        };
        auto infer_shape = [](const MGBOprDesc* self, const MGBTensorShape* input,
                              MGBTensorShape* output) {
            auto p = static_cast<MGBOprDescV23*>(self->user_data);
            p->infer_shape(p, input, output);
        };

        auto v23 = static_cast<MGBOprDescV23*>(v23_raw);
        auto ret = std::make_unique<MGBOprDesc>();
        mgb_init_opr_desc(ret.get(), v23->nr_output, v23->type_name);
        ret->user_data = v23;
#define ASSIGN(name) ret->name = name;
        MGB_OPR_DESC_FOREACH_MEM_FN(ASSIGN);
#undef ASSIGN
        return ret.release();
    }
};
//! impl MGBOprDesc for ExternCOprRunner::make_placeholder
class PlaceholderMGBOprDesc {
    struct UserData {
        std::string name;
        TensorShapeArray output_shapes;
        SmallVector<DType> output_dtypes;
        std::unique_ptr<uint8_t[]> data;
        size_t data_len;
    };

    static UserData* user_data(const MGBOprDesc* self) {
        return static_cast<UserData*>(self->user_data);
    }

    static void release(MGBOprDesc* self) {
        user_data(self)->~UserData();
        ::free(self);
    }

    static size_t hash(const MGBOprDesc* self) {
        return reinterpret_cast<size_t>(self);  // hash disabled
    }

    static int is_same(const MGBOprDesc* self, const MGBOprDesc* rhs) {
        return self == rhs;
    }

    //! perform the computation
    static void execute(const MGBOprDesc*, const MGBTensor*, const MGBTensor*) {
        mgb_throw(MegBrainError, "placeholder ExternCOprRunner cannot be executed");
    }

    static void infer_shape(
            const MGBOprDesc* self, const MGBTensorShape* input,
            MGBTensorShape* output);

    static void infer_dtype(
            const struct MGBOprDesc* self, const MGBDType* input, MGBDType* output);

public:
    static MGBOprDesc* make(
            size_t nr_input, const char* name, const TensorShapeArray& output_shapes,
            const SmallVector<DType>& output_dtypes, const void* data, size_t data_len);

    static void dump(OprDumpContext& ctx, MGBOprDesc* desc);
};
}  // anonymous namespace

/* ===================== PlaceholderMGBOprDesc ===================== */

void PlaceholderMGBOprDesc::infer_shape(
        const MGBOprDesc* self, const MGBTensorShape* input, MGBTensorShape* output) {
    auto ud = user_data(self);
    for (size_t i = 0; i < ud->output_shapes.size(); ++i) {
        output[i] = ExternCOprRunner::tensor_shape_to_c(ud->output_shapes[i]);
    }
}

void PlaceholderMGBOprDesc::infer_dtype(
        const struct MGBOprDesc* self, const MGBDType* input, MGBDType* output) {
    auto ud = user_data(self);
    for (size_t i = 0; i < ud->output_dtypes.size(); ++i) {
        output[i] = dtype_cpp2c(ud->output_dtypes[i]);
    }
}

MGBOprDesc* PlaceholderMGBOprDesc::make(
        size_t nr_input, const char* name, const TensorShapeArray& output_shapes,
        const SmallVector<DType>& output_dtypes, const void* data, size_t data_len) {
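    // layout note: the MGBOprDesc and its UserData share a single malloc'd
    // block; the descriptor sits at the front and UserData is placement-new'd
    // at the next suitably aligned offset, so release() can free both at once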
    constexpr size_t align = std::max(alignof(MGBOprDesc), alignof(UserData)),
                     desc_size = ((sizeof(MGBOprDesc) - 1) / align + 1) * align;
    std::unique_ptr<uint8_t, void (*)(void*)> ptr(
            static_cast<uint8_t*>(malloc(desc_size + sizeof(UserData))), ::free);
    mgb_assert(ptr);
    auto del_ud = [](UserData* p) { p->~UserData(); };
    std::unique_ptr<UserData, decltype(del_ud)> ud(
            new (ptr.get() + desc_size) UserData, del_ud);
    ud->name = name;
    ud->output_shapes = output_shapes;
    ud->output_dtypes = output_dtypes;
    ud->data.reset(new uint8_t[data_len]);
    ud->data_len = data_len;
    memcpy(ud->data.get(), data, data_len);

    auto desc = new (ptr.get()) MGBOprDesc;
    mgb_init_opr_desc(desc, output_shapes.size(), PLACEHOLDER_TYPE_NAME);
    desc->user_data = ud.release();
#define s(n) desc->n = &PlaceholderMGBOprDesc::n;
    MGB_OPR_DESC_FOREACH_MEM_FN(s);
    if (!output_dtypes.empty()) {
        desc->infer_dtype = &PlaceholderMGBOprDesc::infer_dtype;
    }
#undef s
    return reinterpret_cast<MGBOprDesc*>(ptr.release());
}

void PlaceholderMGBOprDesc::dump(OprDumpContext& ctx, MGBOprDesc* desc) {
    mgb_assert(
            desc->type_name == PLACEHOLDER_TYPE_NAME,
            "only placeholder ExternCOprRunner can be dumped; got type %s",
            desc->type_name);
    auto ud = user_data(desc);
    ctx.dump_buf_with_len(ud->name.c_str(), ud->name.size());
    ctx.dump_buf_with_len(ud->data.get(), ud->data_len);
}
/* ===================== ExternCOprRunner ===================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ExternCOprRunner);

ExternCOprRunner::ExternCOprRunner(
        std::string& name, const VarNodeArray& inputs, std::shared_ptr<MGBOprDesc> desc,
        const OperatorNodeConfig& config)
        : Super{inputs[0]->owner_graph(), config, desc->type_name, inputs},
          m_desc{std::move(desc)},
          m_dump_name{name},
          m_param{nullptr} {
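    // a descriptor whose size matches sizeof(MGBOprDesc) was built against the
    // current extern_c_opr.h and supports dynamic params; one that is exactly
    // sizeof(ExternCOprParam*) smaller presumably comes from the previous
    // header revision and is still accepted, just without dynamic params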
    auto size_diff = sizeof(MGBOprDesc) - m_desc->size;
    is_loader_support_dynamic_param = (0 == size_diff);
    mgb_assert(
            0 == size_diff || sizeof(ExternCOprParam*) == size_diff,
            "invalid OprDesc size: expect=%zu got=%u, may be caused by an "
            "extern_c_opr.h mismatch; please confirm that the extern_c_opr.h "
            "used to compile the loader is consistent with the one used to "
            "build the runtime caller",
            sizeof(MGBOprDesc), m_desc->size);
    for (auto i : inputs) {
        add_input({i});
    }
    auto nr_out = m_desc->nr_output;
    if (nr_out > 1) {
        for (size_t i = 0, it = nr_out; i < it; ++i)
            add_output(ssprintf("o%zu", i));
    } else {
        mgb_assert(
                nr_out == 1, "could not create an operator with %u outputs: %s", nr_out,
                cname());
        add_output(None);
    }
    add_equivalence_component<MGBOprDescHash>(m_desc.get());
}
void ExternCOprRunner::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    SmallVector<MGBTensorShape> c_inp(inp_shape.size()), c_out(out_shape.size());
    for (size_t i = 0; i < inp_shape.size(); ++i) {
        c_inp[i] = tensor_shape_to_c(inp_shape[i]);
    }
    m_desc->infer_shape(m_desc.get(), c_inp.data(), c_out.data());
    for (size_t i = 0; i < out_shape.size(); ++i) {
        out_shape[i] = tensor_shape_from_c(c_out[i]);
    }
}

void ExternCOprRunner::init_output_dtype() {
    if (!m_desc->infer_dtype) {
        Super::init_output_dtype();
        return;
    }
    SmallVector<MGBDType> inp_dtypes, out_dtypes(output().size());
    inp_dtypes.reserve(input().size());
    for (auto i : input()) {
        inp_dtypes.push_back(dtype_cpp2c(i->dtype()));
    }
    m_desc->infer_dtype(m_desc.get(), inp_dtypes.data(), out_dtypes.data());
    for (size_t i = 0; i < out_dtypes.size(); ++i) {
        output(i)->dtype(dtype_c2cpp(out_dtypes[i]));
    }
}
void ExternCOprRunner::check_param() {
    //! check validity of the extern dynamic param
    //! nr_input=0 or nr_output=0 means no input/output ExternDeviceTensor is
    //! provided; in that case ExternCOprParam may only configure device_id,
    //! extra_info, etc., so nr_input=0 and nr_output=0 must be tolerated
    if (!is_loader_support_dynamic_param)
        return;

    auto check = [](size_t nr_config_tensor, size_t var_node_size,
                    ExternDeviceTensor* e_tensor, const VarNodeArray& var_node_array,
                    const char* msg) {
        mgb_assert(e_tensor, "%s ExternDeviceTensor should not be null!!", msg);
        mgb_assert(
                nr_config_tensor == var_node_size,
                "param %s size provided by `config_extern_c_opr_dynamic_param` "
                "mismatches the number of %s vars: got %zu, expected %zu",
                msg, msg, nr_config_tensor, var_node_size);
        for (size_t i = 0; i < nr_config_tensor; i++) {
            mgb_assert(
                    e_tensor[i].device_ptr,
                    "%s ExternDeviceTensor(index: %zu) device_ptr should "
                    "not be null!!",
                    msg, i);
            auto param_shape = e_tensor[i].layout.shape;
            auto shape = var_node_array.at(i)->shape();
            auto param_dtype = e_tensor[i].layout.dtype;
            auto dtype = dtype_cpp2c(var_node_array.at(i)->dtype());
            mgb_assert(
                    param_dtype == dtype,
                    "%s dtype mismatch: provided %u, var node expects %d", msg,
                    param_dtype, dtype);
            mgb_assert(
                    shape.ndim == param_shape.ndim,
                    "%s ndim mismatch: got %u, expected %zu at index %zu",
                    msg, param_shape.ndim, shape.ndim, i);
            for (size_t j = 0; j < shape.ndim; j++) {
                mgb_assert(
                        param_shape.shape[j] == shape.shape[j],
                        "configured %s shape should be the same as the c opr %s "
                        "shape: (got: %u, expected: %zu) at index %zu",
                        msg, msg, param_shape.shape[j], shape.shape[j], j);
            }
        }
    };

    if (m_param && m_param->nr_input > 0) {
        check(m_param->nr_input, input().size(), m_param->input, input(), "input");
    }
    if (m_param && m_param->nr_output > 0) {
        check(m_param->nr_output, output().size(), m_param->output, output(), "output");
    }
}
void ExternCOprRunner::scn_do_execute() {
    SmallVector<MGBTensor> c_inp(input().size()), c_out(output().size());
    SmallVector<HostTensorND> cpu_inp, cpu_out;
    check_param();
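    // on a CPU comp node the device tensors can be handed to the loader
    // directly; on other devices they are staged through host tensors and the
    // results copied back after execution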
    bool need_copy = false;
    if (comp_node().device_type() == CompNode::DeviceType::CPU) {
        for (size_t i = 0; i < input().size(); ++i) {
            c_inp[i] = tensor_to_c(input(i)->dev_tensor());
        }
        for (size_t i = 0; i < output().size(); ++i) {
            c_out[i] = tensor_to_c(output(i)->dev_tensor());
        }
    } else {
        need_copy = true;
        mgb_log_debug(
                "copy is needed to execute extern C "
                "opr `%s' on comp node `%s'",
                cname(), comp_node().to_string().c_str());
        cpu_inp.resize(input().size());
        cpu_out.resize(output().size());
        for (size_t i = 0; i < input().size(); ++i) {
            cpu_inp[i].copy_from(input(i)->dev_tensor());
            c_inp[i] = tensor_to_c(cpu_inp[i]);
        }
        for (size_t i = 0; i < output().size(); ++i) {
            cpu_out[i]
                    .comp_node(comp_node())
                    .dtype(output(i)->dtype())
                    .resize(output(i)->shape());
            c_out[i] = tensor_to_c(cpu_out[i]);
        }
    }
    if (need_copy) {
        comp_node().sync();
        m_desc->execute(m_desc.get(), c_inp.data(), c_out.data());
        for (size_t i = 0; i < output().size(); ++i)
            output(i)->dev_tensor().copy_from_fixlayout(cpu_out[i]).sync();
    } else {
        CompNodeEnv::from_comp_node(comp_node())
                .cpu_env()
                .dispatch([this, c_inp, c_out]() mutable {
                    m_desc->execute(m_desc.get(), c_inp.data(), c_out.data());
                });
    }
}

void ExternCOprRunner::add_input_layout_constraint() {
    for (auto i : input())
        i->add_layout_constraint_contiguous();
}
cg::OperatorNodeBase* ExternCOprRunner::make_placeholder(
        const SymbolVarArray& inputs, const TensorShapeArray& output_shapes,
        const char* name, const void* data, size_t data_len,
        const OperatorNodeConfig& config, const SmallVector<DType>& output_dtypes) {
    auto desc = PlaceholderMGBOprDesc::make(
            inputs.size(), name, output_shapes, output_dtypes, data, data_len);
    VarNodeArray var_inp(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        var_inp[i] = inputs[i].node();
    }
    auto dump_name = std::string{name};
    return make_from_desc(dump_name, var_inp, desc, config);
}

cg::OperatorNodeBase* ExternCOprRunner::make_from_desc(
        std::string& name, const VarNodeArray& inputs, MGBOprDesc* desc,
        const OperatorNodeConfig& config) {
    auto desc_del = [](MGBOprDesc* ptr) { ptr->release(ptr); };
    return make_from_desc_shared(name, inputs, {desc, desc_del}, config);
}

cg::OperatorNodeBase* ExternCOprRunner::make_from_desc_shared(
        std::string& name, const VarNodeArray& inputs, std::shared_ptr<MGBOprDesc> desc,
        const OperatorNodeConfig& config) {
    mgb_assert(!inputs.empty() && desc->nr_output);
#define CHECK(name) mgb_assert(desc->name, #name " is not given");
    MGB_OPR_DESC_FOREACH_MEM_FN(CHECK);
#undef CHECK
    if (!config.name().valid())
        const_cast<OperatorNodeConfig&>(config).name(name);
    auto opr = inputs[0]->owner_graph()->insert_opr(
            std::make_unique<ExternCOprRunner>(name, inputs, std::move(desc), config));
    return &opr->cast_final_safe<ExternCOprRunner>();
}
bool ExternCOprRunner::unregister_loader(const char* name) {
    return loader_map().erase(name);
}

void ExternCOprRunner::dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
    auto&& opr = opr_.cast_final<ExternCOprRunner>();
    PlaceholderMGBOprDesc::dump(ctx, opr.m_desc.get());
}

cg::OperatorNodeBase* ExternCOprRunner::load(
        OprLoadContext& ctx, const cg::VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto dump_name = ctx.load_buf_with_len();
    auto name = dump_name;
    //! keep compatibility with dumps whose ExternCOprRunner name carries extra
    //! info after a ':'; only the part before it is used for loader lookup
    if (auto index = name.find(":"))
        name = name.substr(0, index);
    auto&& map = loader_map();
    auto iter = map.find(name);
    mgb_assert(
            iter != map.end(), "cannot find loader for ExternCOprRunner `%s'",
            name.c_str());
    auto data = ctx.load_shared_buf_with_len();
    auto desc = iter->second.first.create_desc(inputs.size(), data.data(), data.size());
    mgb_throw_if(nullptr == desc, MegBrainError, "loader create desc returns nullptr");
    if (auto trans = iter->second.second) {
        desc = trans(desc);
    }
    mgb_throw_if(nullptr == desc, MegBrainError, "loader create desc returns nullptr");
    return make_from_desc(dump_name, inputs, desc, config);
}
cg::OperatorNodeBase* ExternCOprRunner::shallow_copy(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<ExternCOprRunner>();
    auto dump_name = opr.m_dump_name;
    return make_from_desc_shared(dump_name, inputs, opr.m_desc, config);
}

MGBTensorShape ExternCOprRunner::tensor_shape_to_c(const TensorShape& shape) {
    mgb_throw_if(
            shape.ndim > MGB_TENSOR_MAX_NDIM, MegBrainError,
            "shape ndim too large: %zu", shape.ndim);
    MGBTensorShape ret;
    ret.ndim = shape.ndim;
    for (size_t i = 0; i < shape.ndim; ++i) {
        ret.shape[i] = shape[i];
    }
    return ret;
}

TensorShape ExternCOprRunner::tensor_shape_from_c(const MGBTensorShape& shape) {
    mgb_assert(
            shape.ndim <= TensorShape::MAX_NDIM, "shape ndim too large: %u",
            shape.ndim);
    TensorShape ret;
    ret.ndim = shape.ndim;
    for (size_t i = 0; i < shape.ndim; ++i) {
        ret.shape[i] = shape.shape[i];
    }
    return ret;
}

void mgb::config_extern_c_opr_dynamic_param(
        std::unique_ptr<cg::AsyncExecutable>& func,
        std::shared_ptr<ExternCOprParam> param) {
    mgb_throw_if(!param, MegBrainError, "invalid ExternCOprParam param!!");

    auto find_config_opr = false;
    auto cb = [&](cg::OperatorNodeBase* opr) {
        if (auto c_opr = opr->try_cast_final<opr::ExternCOprRunner>()) {
            auto dump_name = c_opr->get_dump_name().c_str();
            if (!param->extern_c_opr_dump_name ||
                !strncmp(param->extern_c_opr_dump_name, dump_name, strlen(dump_name))) {
                c_opr->set_param(param);
                find_config_opr = true;
                mgb_log_debug("config dynamic param for extern c opr: %s", dump_name);
            }
        }
        return !find_config_opr;
    };
    func->iter_opr_seq(cb);
    mgb_throw_if(
            !find_config_opr, MegBrainError,
            "graph does not include an ExternCOprRunner opr, or "
            "extern_c_opr_dump_name is misconfigured!!");
}

/* ===================== public APIs ===================== */
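
//! usage sketch (illustrative only, not part of this translation unit): an
//! external loader typically fills an MGBOprLoader and hands it to the api
//! table returned by mgb_get_extern_c_opr_api_versioned(); `my_create_desc`
//! and the member name `register_loader` are assumptions based on the header:
//!
//!     const MGBExternCOprApi* api =
//!             mgb_get_extern_c_opr_api_versioned(MGB_EXTERN_C_OPR_VERSION);
//!     MGBOprLoader loader;
//!     loader.name = "my_opr";               // key used by loader_map()
//!     loader.create_desc = my_create_desc;  // builds an MGBOprDesc from dumped data
//!     api->register_loader(&loader);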
const MGBExternCOprApi* mgb_get_extern_c_opr_api_versioned(int version) {
    auto unreg = [](const char* name) -> int {
        return ExternCOprRunner::unregister_loader(name);
    };
    if (version == 0x23) {
        auto reg23 = [](const MGBOprLoader* loader) -> int {
            return loader_map()
                    .insert({loader->name, {*loader, MGBOprDescV23::as_opr_desc}})
                    .second;
        };
        static const MGBExternCOprApi ret = {reg23, unreg};
        return &ret;
    }
    if (version != MGB_EXTERN_C_OPR_VERSION)
        return nullptr;
    auto reg = [](const MGBOprLoader* loader) -> int {
        return loader_map().insert({loader->name, {*loader, nullptr}}).second;
    };
    static const MGBExternCOprApi ret = {reg, unreg};
    return &ret;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}