You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

rng.cpp 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. /**
  2. * \file imperative/src/impl/ops/rng.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/imperative/ops/rng.h"
  13. #include "megbrain/comp_node_env.h"
  14. #include "megbrain/graph/helper.h"
  15. #include "megbrain/opr/rand.h"
  16. #include "../dnn_op_helper.h"
  17. #include "../op_trait.h"
  18. namespace mgb::imperative::rng {
  19. namespace {
/*!
 * \brief CRTP base class that owns, per handle, a cache of lazily created
 * megdnn operators (one per concrete op type).
 *
 * \tparam HandleFactory the concrete subclass; must provide
 *         do_new_handle() / do_delete_handle() (CRTP dispatch).
 * \tparam THandle opaque handle type used as the cache key.
 *
 * As a CompNodeDepedentObject, the whole cache is dropped automatically
 * when comp nodes are finalized (see on_comp_node_finalize).
 */
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
    using DT = CompNode::DeviceType;
    using Handle = THandle;
    using OpTypeInfo = size_t;

    //! Create a new handle by forwarding to the concrete factory.
    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    //! Drop all cached dnn ops of \p handle, then destroy the handle itself.
    //! \return number of cache entries erased (0 or 1).
    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2ops.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    /*!
     * \brief Get (creating lazily on first use) the cached dnn op of type
     * \p DnnOp for (\p handle, \p tpinfo) on comp node \p cn.
     *
     * \return tuple (initialized, op pointer, lock): \c initialized tells
     * whether the op already existed (so its param is carried over from a
     * previous call); the returned unique_lock keeps the per-op mutex held,
     * giving the caller exclusive use of the op until the lock is released.
     */
    template <typename DnnOp>
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            // guard only the map lookup; the per-op mutex is taken below
            MGB_LOCK_GUARD(m_mtx);
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
        }
        auto dnn_handle =
                MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
            // a cached op must belong to the same dnn handle
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
    DnnOpManagerT() = default;

private:
    //! a cached dnn op together with the mutex serializing its use
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
        DnnOpWithMutex(): op{nullptr} {}
    };

    //! drop every cached op when comp nodes are finalized
    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2ops.clear();
        return {};
    }

    //! handle -> (op type info -> cached op); guarded by m_mtx
    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex> > m_handle2ops;
    std::mutex m_mtx;
};
/*!
 * \brief Singleton manager of RNG handles.
 *
 * Each handle binds a comp node and a seed.  A "default" handle per comp
 * node is created on demand with the global default seed; passing a null
 * handle to an RNG op selects that default handle.
 */
class RNGDnnOpManager final
        : public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }

    size_t delete_handle(Handle handle) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
    }

    //! allocate handle storage from the pool (called via CRTP base)
    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    //! seed of \p handle; a null handle maps to the global default seed
    static uint64_t get_seed(Handle handle) {
        if (!handle) { return glob_default_seed; }
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    //! comp node of \p handle; the handle must be non-null
    static CompNode get_comp_node(Handle handle) {
        mgb_assert(handle, "invalid handle");
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    //! get (creating on first use) the default handle of \p comp_node
    static Handle get_default_handle(CompNode comp_node) {
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        }
        mgb_assert(get_seed(glob_handle) == glob_default_seed);
        return glob_handle;
    }

    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    //! Reset the global default seed, re-creating every existing default
    //! handle so it picks up the new seed.
    static void set_glob_default_seed(uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        for (auto&& elem : glob_default_handles) {
            mgb_assert(elem.first.valid());
            if (elem.second) {
                // call the base directly: the locking wrapper above would
                // try to re-acquire sm_mtx, which is already held here
                inst().DnnOpManagerBase::delete_handle(elem.second);
            }
            elem.second = inst().do_new_handle(elem.first, seed);
        }
        glob_default_seed = seed;
    }

    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }

private:
    //! payload a Handle actually points to
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;
    //! guards glob_default_seed and glob_default_handles
    static std::mutex sm_mtx;
    static CompNode::UnorderedMap<Handle> glob_default_handles;
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;
  149. template <typename Op>
  150. struct OpMeth;
  151. template <>
  152. struct OpMeth<UniformRNG> {
  153. using DnnOp = megdnn::UniformRNG;
  154. using Param = DnnOp::Param;
  155. using OpNode = mgb::opr::UniformRNG;
  156. static Param make_param(const UniformRNG& rng) {
  157. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  158. mgb_assert(handle_seed == rng.seed,
  159. "inconsistent rng seed: rng op: %lu handle: %lu",
  160. handle_seed, rng.seed);
  161. return {handle_seed, rng.dtype.enumv()};
  162. }
  163. };
  164. template <>
  165. struct OpMeth<PoissonRNG> {
  166. using DnnOp = megdnn::PoissonRNG;
  167. using Param = DnnOp::Param;
  168. using OpNode = mgb::opr::PoissonRNG;
  169. static Param make_param(const PoissonRNG& rng) {
  170. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  171. mgb_assert(handle_seed == rng.seed,
  172. "inconsistent rng seed: rng op: %lu handle: %lu",
  173. handle_seed, rng.seed);
  174. return {handle_seed};
  175. }
  176. };
  177. template <>
  178. struct OpMeth<GaussianRNG> {
  179. using DnnOp = megdnn::GaussianRNG;
  180. using Param = DnnOp::Param;
  181. using OpNode = mgb::opr::GaussianRNG;
  182. static Param make_param(const GaussianRNG& rng) {
  183. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  184. mgb_assert(handle_seed == rng.seed,
  185. "inconsistent rng seed: rng op: %lu handle: %lu",
  186. handle_seed, rng.seed);
  187. return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
  188. }
  189. };
  190. template <>
  191. struct OpMeth<GammaRNG> {
  192. using DnnOp = megdnn::GammaRNG;
  193. using Param = DnnOp::Param;
  194. using OpNode = mgb::opr::GammaRNG;
  195. static Param make_param(const GammaRNG& rng) {
  196. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  197. mgb_assert(handle_seed == rng.seed,
  198. "inconsistent rng seed: rng op: %lu handle: %lu",
  199. handle_seed, rng.seed);
  200. return {handle_seed};
  201. }
  202. };
  203. template <>
  204. struct OpMeth<PermutationRNG> {
  205. using DnnOp = megdnn::PermutationRNG;
  206. using Param = DnnOp::Param;
  207. using OpNode = mgb::opr::PermutationRNG;
  208. static Param make_param(const PermutationRNG& rng) {
  209. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  210. mgb_assert(handle_seed == rng.seed,
  211. "inconsistent rng seed: rng op: %lu handle: %lu",
  212. handle_seed, rng.seed);
  213. return {handle_seed, rng.dtype.enumv()};
  214. }
  215. };
  216. template <>
  217. struct OpMeth<BetaRNG> {
  218. using DnnOp = megdnn::BetaRNG;
  219. using Param = DnnOp::Param;
  220. using OpNode = mgb::opr::BetaRNG;
  221. static Param make_param(const BetaRNG& rng) {
  222. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  223. mgb_assert(handle_seed == rng.seed,
  224. "inconsistent rng seed: rng op: %lu handle: %lu",
  225. handle_seed, rng.seed);
  226. return {handle_seed};
  227. }
  228. };
  229. template <>
  230. struct OpMeth<ShuffleRNG> {
  231. using DnnOp = megdnn::ShuffleRNG;
  232. using Param = DnnOp::Param;
  233. using OpNode = mgb::opr::ShuffleRNG;
  234. static Param make_param(const ShuffleRNG& rng) {
  235. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  236. mgb_assert(handle_seed == rng.seed,
  237. "inconsistent rng seed: rng op: %lu handle: %lu",
  238. handle_seed, rng.seed);
  239. return {handle_seed};
  240. }
  241. };
//! infers the output layout; specialized on whether the dnn op consumes a
//! target-shape tensor (true) or real data inputs (false)
template <bool>
struct _InferLayout;

//! builds the graph operator for an RNG op with nr_in mgb-level inputs
template <int nr_in>
struct _RNGOprMaker;

//! runs the dnn op for the given input/output arity
template <int nr_in, int nr_out>
struct _RNGOprInvoker;
  248. template<>
  249. struct _InferLayout<true>
  250. {
  251. template<typename Op>
  252. static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){
  253. TensorShape tshape;
  254. auto hv = inp->get_value().proxy_to_default_cpu();
  255. cg::copy_tensor_value_to_shape(tshape, hv);
  256. return TensorLayout(tshape, rng.dtype);
  257. }
  258. template<typename Op>
  259. static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){
  260. TensorLayout out_layout = inp.layout;
  261. out_layout.dtype = rng.dtype;
  262. if (inp.layout.ndim == 0 || inp.value.empty()) {
  263. out_layout.ndim = 0;
  264. return out_layout;
  265. }
  266. mgb_assert(
  267. inp.layout.ndim == 1,
  268. "target shape of %s expects ndim=1; got ndim=%lu actually",
  269. rng.dyn_typeinfo()->name,
  270. inp.layout.ndim);
  271. size_t target_ndim = inp.layout.shape[0];
  272. out_layout.ndim = target_ndim;
  273. auto* ptr = inp.value.ptr<dt_int32>();
  274. for (size_t i = 0; i < target_ndim; ++i) {
  275. out_layout.shape[i] = ptr[i];
  276. }
  277. return out_layout;
  278. }
  279. };
//! Layout inference when the dnn op consumes real data inputs: the output
//! simply inherits the first input's layout.
template <>
struct _InferLayout<false> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        return inp->layout();
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        // data inputs must already have a known shape
        mgb_assert(inp.layout.ndim);
        return inp.layout;
    }
};
// Instantiate _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS>.  The helper
// macros _FOR_EACH_IN / _FOR_EACH_OUT must be #defined beforehand to expand
// `subfix` after each input / output tensor for the given arity (see the
// instantiations below).  The invoker queries the dnn workspace size,
// allocates the workspace as a Blob on the first output's comp node, and
// then executes the dnn op.  (No comments inside the macro body: a `//`
// would swallow the line-continuation backslash.)
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS)                    \
    template <>                                                              \
    struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> {                   \
        template <typename Opr>                                              \
        static void exec(Opr* dnn_op, const SmallVector<TensorPtr>& inputs,  \
                const SmallVector<TensorPtr>& outputs) {                     \
            size_t wk_size = 0;                                              \
            wk_size = dnn_op->get_workspace_in_bytes(                        \
                    _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout()));     \
            auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);   \
            megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);   \
            dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn())            \
                    _FOR_EACH_OUT(->dev_tensor().as_megdnn()),               \
                    dnn_wk);                                                 \
        }                                                                    \
    };
// Instantiate _RNGOprMaker<MGB_NR_INPUTS>, which builds the graph operator
// for an RNG op.  The operator's comp node is taken from the op's handle
// when one is set; otherwise the graph chooses.  Requires _FOR_EACH_IN to
// be #defined for the given arity.
#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                       \
    template <>                                                              \
    struct _RNGOprMaker<MGB_NR_INPUTS> {                                     \
        template <typename Op>                                               \
        static auto make(const VarNodeArray& inputs, const Op& rng) {        \
            auto param = OpMeth<Op>::make_param(rng);                        \
            OperatorNodeConfig config;                                       \
            if (rng.handle) {                                                \
                config = {rng.make_name(),                                   \
                        RNGDnnOpManager::get_comp_node(rng.handle)};         \
            } else {                                                         \
                config = {rng.make_name()};                                  \
            }                                                                \
            return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config);   \
        }                                                                    \
    };
// ---- instantiations for each (inputs, outputs) arity actually used -----
// 0 dnn inputs / 1 output: the shape comes from the param.  No maker is
// instantiated here because at mgb level such ops still take the shape var
// (mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp), so the 1-input maker is used.
#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

// 1 dnn input / 1 output
#define _FOR_EACH_IN(subfix) inputs[0] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT

// 1 dnn input / 2 outputs (presumably ShuffleRNG: data + indices)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

// 2 dnn inputs / 1 output
#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER
  347. template <typename Op>
  348. void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
  349. const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
  350. auto&& rng = op.cast_final_safe<Op>();
  351. auto dest = outputs[0];
  352. if (dest->layout().is_empty()) return;
  353. auto cn = dest->comp_node();
  354. auto handle = rng.handle;
  355. if (!handle) {
  356. handle = RNGDnnOpManager::get_default_handle(cn);
  357. }
  358. // retrieve dnn_op from glob cache
  359. auto dnn_op_thread_safe = RNGDnnOpManager::inst()
  360. .get_dnn_op<typename OpMeth<Op>::DnnOp>(
  361. handle, reinterpret_cast<size_t>(op.dyn_typeinfo()),
  362. cn);
  363. auto initialized = std::get<0>(dnn_op_thread_safe);
  364. auto dnn_op = std::get<1>(dnn_op_thread_safe);
  365. if (initialized) {
  366. auto handle_seed = RNGDnnOpManager::get_seed(handle);
  367. mgb_assert(dnn_op->param().seed == handle_seed,
  368. "inconsistent rng seed: handle: %lu, dnn_op: %lu",
  369. handle_seed, dnn_op->param().seed);
  370. }
  371. dnn_op->param() = OpMeth<Op>::make_param(rng);
  372. _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS,
  373. OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(dnn_op, inputs,
  374. outputs);
  375. }
  376. template <typename Op>
  377. SmallVector<LogicalTensorDesc> infer_output_attrs(
  378. const OpDef& op, const SmallVector<TensorPtr>& inputs) {
  379. LogicalTensorDesc dest;
  380. auto&& rng = op.cast_final_safe<Op>();
  381. auto handle = rng.handle;
  382. if (handle) {
  383. dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
  384. } else {
  385. dest.comp_node = inputs[0]->comp_node();
  386. }
  387. constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
  388. if(!rng_with_shape){
  389. for(int i = 0; i < inputs.size(); ++i){
  390. mgb_assert(inputs[i]->comp_node() == dest.comp_node,
  391. "%s expects the device of inputs[%d] to be same as the device of handle; "
  392. "got %s and %s actually", rng.dyn_typeinfo()->name, i,
  393. inputs[i]->comp_node().to_string().c_str(),
  394. dest.comp_node.to_string().c_str());
  395. }
  396. }
  397. dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
  398. return {dest};
  399. }
  400. template <>
  401. SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
  402. const OpDef& op, const SmallVector<TensorPtr>& inputs) {
  403. SmallVector<LogicalTensorDesc> dests(2);
  404. auto&& rng = op.cast_final_safe<ShuffleRNG>();
  405. auto handle = rng.handle;
  406. if (handle) {
  407. dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
  408. dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
  409. } else {
  410. dests[0].comp_node = inputs[0]->comp_node();
  411. dests[1].comp_node = inputs[0]->comp_node();
  412. }
  413. dests[0].layout = TensorLayout(inputs[0]->layout());
  414. dests[0].layout.dtype = inputs[0]->layout().dtype;
  415. dests[1].layout =
  416. TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
  417. return dests;
  418. }
  419. template <typename Op>
  420. std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
  421. infer_output_mem_desc(const OpDef& def,
  422. const SmallVector<TensorPtr>& inputs_tensors,
  423. const SmallVector<MemoryDesc>& inputs_mems) {
  424. auto&& dests = infer_output_attrs<Op>(def, inputs_tensors);
  425. SmallVector<MemoryDesc> outputs;
  426. for (size_t i = 0; i < dests.size(); ++i) {
  427. outputs.push_back({dests[i].layout, 0, dests[i].comp_node,
  428. StorageIdentifier::make(i + 1)});
  429. }
  430. return {outputs, {}};
  431. }
  432. template <typename Op>
  433. SmallVector<TensorPtr> apply_on_physical_tensor(
  434. const OpDef& def, const SmallVector<TensorPtr>& inputs) {
  435. SmallVector<TensorPtr> outputs;
  436. SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
  437. for (auto&& i : desc) {
  438. outputs.push_back(Tensor::make(i.layout, i.comp_node));
  439. }
  440. exec<Op>(def, inputs, outputs, {});
  441. return outputs;
  442. }
//! OpDef-trait entry point with caller-provided outputs.  The workspace
//! argument is ignored: the dnn workspace is allocated inside exec().
template <typename Op>
void execute(
        const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    exec<Op>(def, inputs, outputs, {});
}
  451. template <typename Op, typename Output>
  452. Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  453. size_t nr_inp = inputs.size();
  454. constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
  455. auto&& rng = def.cast_final_safe<Op>();
  456. if(dnn_nr_inp == 0){
  457. mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
  458. rng.dyn_typeinfo()->name,
  459. nr_inp);
  460. }
  461. constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
  462. return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
  463. }
  464. template<typename Op>
  465. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
  466. const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
  467. LogicalTensorDesc dest;
  468. auto&& xxx_rng_def = def.cast_final_safe<Op>();
  469. size_t nr_inp = inputs.size();
  470. constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
  471. if (rng_with_shape){
  472. mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
  473. xxx_rng_def.dyn_typeinfo()->name,
  474. nr_inp);
  475. }
  476. dest.comp_node = inputs[0].comp_node;
  477. dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
  478. return {{dest}, true};
  479. }
  480. template <>
  481. std::tuple<SmallVector<LogicalTensorDesc>, bool>
  482. infer_output_attrs_fallible<ShuffleRNG>(
  483. const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
  484. SmallVector<LogicalTensorDesc> dests(2);
  485. dests[0].comp_node = inputs[0].comp_node;
  486. dests[0].layout = TensorLayout(inputs[0].layout);
  487. dests[0].layout.dtype = inputs[0].layout.dtype;
  488. dests[1].comp_node = inputs[0].comp_node;
  489. dests[1].layout = TensorLayout(TensorShape({inputs[0].layout.shape[0]}),
  490. dtype::Int32());
  491. return {dests, true};
  492. }
  493. } // anonymous namespace
//! Create an RNG handle bound to \p comp_node with the given \p seed.
Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}

//! Destroy \p handle.  Returns the number of dnn-op cache entries dropped
//! (0 or 1).
size_t delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}

//! Set the seed used by all default (null) handles; existing default
//! handles are re-created with the new seed.
void set_global_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}

uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}

//! Comp node a (non-null) handle is bound to.
CompNode get_rng_handle_compnode(Handle handle) {
    return RNGDnnOpManager::get_comp_node(handle);
}
// Register OpDef traits for every RNG op: graph lowering, eager execution,
// fallible layout inference and memory planning all dispatch to the
// templates defined above.
#define REG_RNG_OP(NAME, Output)                                         \
    namespace {                                                          \
    OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode)                       \
            .apply_on_var_node(apply_on_var_node<NAME, Output>)          \
            .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)    \
            .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
            .infer_output_mem_desc(infer_output_mem_desc<NAME>)          \
            .execute(execute<NAME>)                                      \
            .fallback();                                                 \
    }

REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)  // two outputs -> SymbolVarArray
#undef REG_RNG_OP
  527. } // namespace mgb::imperative::rng
  528. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台