
shared_nd.i

/*
 * $File: shared_nd.i
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * $Copyright: Copyright (c) 2014-2017 Megvii Inc. All rights reserved.
 */

%pythoncode {
from .mgb_helper import SharedNDLazyInitializer
} // pythoncode

%feature("autodoc", """a value stored on computing device and can be modified
by special operators in the graph""") SharedND;
class SharedND {
    public:
        SharedND(CompNode comp_node, PyObject *dtype);

        void _set_init_shape(const std::vector<size_t> &shape);
        void _resize(const std::vector<size_t> &shape);
        void _reset_zero();
        PyObject* _get_npyarr();
        PyObject* _get_dtype();
        std::vector<size_t> _get_shape();
        void _copy_from_npyarr(PyObject *npyarr);
        void _copy_from_value_proxy(CompGraphCallbackValueProxy &value);
        void _share_from_value_proxy(CompGraphCallbackValueProxy &value);
        static SharedND _from_symvar(SymbolVar symvar);
        void _set_copy_sync(bool flag);
        uintptr_t _pubapi_dev_tensor_ptr(int version);
        void copy_to_sub_from_shared(
                int axis, ptrdiff_t begin, ptrdiff_t end, ptrdiff_t step,
                const SharedND &rhs);
        void copy_from_shared_sub(const SharedND &rhs,
                int axis, ptrdiff_t begin, ptrdiff_t end, ptrdiff_t step);
        CompNode _get_comp_node();
        SymbolVar _as_sym_var(CompGraph &graph, const std::string &name,
                bool volatile_);
        void _share_memory_from(const SharedND &rhs, size_t begin);
        void _reset_dev_tensor(const SharedND& rhs);

        %include "shared_nd_SharedND.py"
};

%template(_VectorSharedND) std::vector<SharedND>;

class _HostSharedND {
    public:
        _HostSharedND(CompNode node, PyObject *dtype);

        static _HostSharedND make_proxy(SymbolVar var);
        SymbolVar _as_sym_var(CompGraph &cg, bool enable_static_infer,
                const std::string &name);
        PyObject* _get_dtype();
        void _resize(const std::vector<size_t> &shape);
        void _copy_from_npyarr(PyObject *npyarr, bool borrow);
        void _enable_borrow_on_cpu(bool flag);
        std::string __repr__() const;

        %include "shared_nd_HostSharedND.py"
};

%feature("autodoc",
"""a scalar value that can be modified after it has been created;
compared to :class:`SharedND`, it has the advantage that no comp node needs to
be specified.""") SharedScalar;
class SharedScalar {
    public:
        SharedScalar(PyObject *val);
        void _set(PyObject *val);
        PyObject* _get();
        bool _dtype_locked();
        void _lock_dtype();
        SymbolVar _as_sym_var(CompGraph &cg, CompNode &cn);

        %pythoncode {

def lock_dtype(self):
    """lock dtype so further set() calls must pass the same dtyped
    value"""
    self._lock_dtype()

@property
def dtype_locked(self):
    """whether dtype is locked"""
    return self._dtype_locked()

def set(self, val):
    self._set(val)

def get(self):
    """get the value stored in this SharedScalar"""
    return self._get()[0]

def __getstate__(self):
    state = self.__dict__.copy()
    del state['this']
    state['__shared_scalar_value'] = self.get()
    state['__shared_scalar_dtype_locked'] = self.dtype_locked
    return state

def __setstate__(self, state):
    val = SharedScalar(state.pop('__shared_scalar_value'))
    if state.pop('__shared_scalar_dtype_locked', True):
        val._lock_dtype()
    self.this = val.this
    for k, v in state.items():
        self.__dict__[k] = v

def __repr__(self):
    return 'SharedScalar({})'.format(self.get())

}
};

// vim: ft=swig
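For reference, here is a minimal usage sketch of the SharedScalar wrapper declared above. The import path is an assumption (the exact module depends on how the bindings are packaged); the method names come straight from the %pythoncode block in this file.

import pickle

from megbrain import SharedScalar  # hypothetical import path; adjust to your build

s = SharedScalar(1.5)      # wrap a Python value
print(s)                   # "SharedScalar(1.5)" via __repr__
s.set(2.5)                 # forwards to the wrapped _set()
assert s.get() == 2.5      # get() unwraps the stored value

s.lock_dtype()             # later set() calls must keep the same dtype
assert s.dtype_locked

# __getstate__/__setstate__ above preserve the value and the dtype-lock
# flag across a pickle round trip
s2 = pickle.loads(pickle.dumps(s))
assert s2.get() == 2.5 and s2.dtype_locked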

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. If you want to run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
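As a quick check that the GPU path is usable before running GPU programs, something like the following sketch can be used; it assumes a MegEngine release that exposes is_cuda_available(), and the "gpu0"/"cpu0" device strings are illustrative.

import megengine

# Fall back to CPU when no usable GPU (hardware plus driver) is detected.
# is_cuda_available() is an assumption here; the helper name may differ
# between MegEngine releases.
device = "gpu0" if megengine.is_cuda_available() else "cpu0"
print("running on", device)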