
feat(whl/api/lar): enable megengine dll on Windows

1: reduce the Python whl package size
2: unify the API link logic across all OSes
3: add option: MGE_WINDOWS_BUILD_WITH_STATIC_CRT
    --- default OFF
    --- if the CRT (VC runtime) is statically linked together with megengine.dll,
        some CRT APIs will crash, for example flush; so if you build a static
        megengine and do not want to install the CRT, you can set
        MGE_WINDOWS_BUILD_WITH_STATIC_CRT to TRUE
    --- how to install the CRT:
        https://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160
        install VC_redist.x64.exe
4: rename megengine_export to megengine_shared (only the needed symbols are
   exported), because its runtime symbols conflicted with pytorch
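Item 4 relies on explicit per-symbol export annotations instead of exporting every symbol. A minimal sketch of how such a macro is typically wired up on Windows follows; the real definition lives in src/megbrain_build_config.h.in (also touched by this commit) and may differ in detail:

    // Sketch only -- not the actual megbrain_build_config.h.in contents.
    // MGE_DLL_EXPORT is defined by src/CMakeLists.txt when building megengine_shared.dll;
    // consumers such as _imperative_rt.pyd leave it undefined.
    #ifdef _WIN32
    #ifdef MGE_DLL_EXPORT
    #define MGE_WIN_DECLSPEC_FUC __declspec(dllexport)
    #else
    #define MGE_WIN_DECLSPEC_FUC __declspec(dllimport)
    #endif
    #else
    #define MGE_WIN_DECLSPEC_FUC  // no-op outside Windows
    #endif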

GitOrigin-RevId: 93d8d80f29
tags/v1.7.0
Megvii Engine Team, 4 years ago
commit 25ec2530ba
100 changed files with 1381 additions and 1132 deletions
  1. CMakeLists.txt  (+28, -22)
  2. dnn/cuda-stub/CMakeLists.txt  (+1, -0)
  3. dnn/cuda-stub/src/libcuda.cpp  (+24, -2)
  4. dnn/include/megdnn/basic_types.h  (+53, -44)
  5. dnn/include/megdnn/dtype.h  (+19, -17)
  6. dnn/include/megdnn/handle.h  (+13, -12)
  7. dnn/include/megdnn/oprs/general.h  (+9, -6)
  8. dnn/include/megdnn/oprs/nn_int.h  (+1, -1)
  9. dnn/include/megdnn/thin/small_vector.h  (+4, -2)
  10. dnn/include/megdnn/version.h  (+2, -1)
  11. imperative/CMakeLists.txt  (+4, -4)
  12. imperative/python/megengine/__init__.py  (+4, -0)
  13. lite/CMakeLists.txt  (+8, -4)
  14. lite/pylite/megenginelite/base.py  (+4, -0)
  15. scripts/whl/macos/macos_build_whl.sh  (+3, -3)
  16. scripts/whl/manylinux2014/do_build_common.sh  (+4, -4)
  17. scripts/whl/windows/windows_build_whl.sh  (+20, -25)
  18. sdk/load-and-run/CMakeLists.txt  (+17, -14)
  19. src/CMakeLists.txt  (+39, -29)
  20. src/core/impl/comp_node/cuda/comp_node.cpp  (+13, -0)
  21. src/core/impl/graph/static_infer_impl.h  (+6, -6)
  22. src/core/impl/tensor.cpp  (+4, -4)
  23. src/core/include/megbrain/common.h  (+8, -7)
  24. src/core/include/megbrain/comp_node.h  (+37, -34)
  25. src/core/include/megbrain/comp_node_env.h  (+12, -10)
  26. src/core/include/megbrain/dtype.h  (+10, -7)
  27. src/core/include/megbrain/exception.h  (+1, -1)
  28. src/core/include/megbrain/graph/cg.h  (+10, -7)
  29. src/core/include/megbrain/graph/event.h  (+15, -15)
  30. src/core/include/megbrain/graph/extern_copr_api.h  (+1, -1)
  31. src/core/include/megbrain/graph/grad_impl.h  (+8, -8)
  32. src/core/include/megbrain/graph/helper.h  (+36, -29)
  33. src/core/include/megbrain/graph/operator_node.h  (+48, -38)
  34. src/core/include/megbrain/graph/static_infer.h  (+3, -3)
  35. src/core/include/megbrain/graph/symbol_var.h  (+11, -10)
  36. src/core/include/megbrain/graph/var_node.h  (+37, -29)
  37. src/core/include/megbrain/system.h  (+7, -7)
  38. src/core/include/megbrain/tensor.h  (+36, -29)
  39. src/core/include/megbrain/utils/debug.h  (+6, -5)
  40. src/core/include/megbrain/utils/event.h  (+3, -3)
  41. src/core/include/megbrain/utils/hash.h  (+5, -4)
  42. src/core/include/megbrain/utils/hashable.h  (+3, -3)
  43. src/core/include/megbrain/utils/infile_persistent_cache.h  (+10, -7)
  44. src/core/include/megbrain/utils/json.h  (+18, -16)
  45. src/core/include/megbrain/utils/mempool.h  (+10, -9)
  46. src/core/include/megbrain/utils/metahelper.h  (+17, -5)
  47. src/core/include/megbrain/utils/metahelper_basic.h  (+2, -2)
  48. src/core/include/megbrain/utils/persistent_cache.h  (+2, -2)
  49. src/core/include/megbrain/utils/thread_impl_1.h  (+9, -9)
  50. src/core/include/megbrain/utils/timer.h  (+2, -1)
  51. src/core/include/megbrain/version.h  (+3, -1)
  52. src/custom/include/megbrain/custom/manager.h  (+19, -15)
  53. src/custom/include/megbrain/custom/op.h  (+51, -45)
  54. src/custom/include/megbrain/custom/param.h  (+8, -8)
  55. src/custom/include/megbrain/custom/param_val.h  (+17, -16)
  56. src/custom/include/megbrain/custom/tensor.h  (+56, -55)
  57. src/custom/include/megbrain/custom/utils.h  (+3, -3)
  58. src/gopt/include/megbrain/gopt/framework.h  (+9, -9)
  59. src/gopt/include/megbrain/gopt/inference.h  (+8, -6)
  60. src/megbrain_build_config.h.in  (+39, -19)
  61. src/opr/include/megbrain/opr/basic_arith.h  (+20, -17)
  62. src/opr/include/megbrain/opr/blas.h  (+10, -8)
  63. src/opr/include/megbrain/opr/custom_opnode.h  (+13, -13)
  64. src/opr/include/megbrain/opr/dnn/adaptive_pooling.h  (+4, -4)
  65. src/opr/include/megbrain/opr/dnn/batch_norm.h  (+8, -8)
  66. src/opr/include/megbrain/opr/dnn/convolution.h  (+63, -62)
  67. src/opr/include/megbrain/opr/dnn/correlation.h  (+6, -6)
  68. src/opr/include/megbrain/opr/dnn/fake_quant.h  (+5, -5)
  69. src/opr/include/megbrain/opr/dnn/images2neibs.h  (+4, -4)
  70. src/opr/include/megbrain/opr/dnn/lrn.h  (+5, -4)
  71. src/opr/include/megbrain/opr/dnn/lsq.h  (+4, -4)
  72. src/opr/include/megbrain/opr/dnn/pooling.h  (+5, -5)
  73. src/opr/include/megbrain/opr/dnn/roi_align.h  (+4, -4)
  74. src/opr/include/megbrain/opr/dnn/roi_pooling.h  (+11, -11)
  75. src/opr/include/megbrain/opr/dnn/sliding_window_transpose.h  (+5, -5)
  76. src/opr/include/megbrain/opr/dnn/tqt.h  (+4, -4)
  77. src/opr/include/megbrain/opr/imgproc.h  (+16, -16)
  78. src/opr/include/megbrain/opr/indexing.h  (+8, -8)
  79. src/opr/include/megbrain/opr/internal/indexing_helper.h  (+4, -4)
  80. src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h  (+24, -17)
  81. src/opr/include/megbrain/opr/io.h  (+24, -20)
  82. src/opr/include/megbrain/opr/loop.h  (+1, -1)
  83. src/opr/include/megbrain/opr/misc.h  (+32, -25)
  84. src/opr/include/megbrain/opr/nn_int.h  (+3, -3)
  85. src/opr/include/megbrain/opr/rand.h  (+6, -6)
  86. src/opr/include/megbrain/opr/standalone/nms_opr.h  (+6, -4)
  87. src/opr/include/megbrain/opr/tensor_gen.h  (+3, -3)
  88. src/opr/include/megbrain/opr/tensor_manip.h  (+105, -97)
  89. src/opr/include/megbrain/opr/utility.h  (+66, -48)
  90. src/plugin/include/megbrain/plugin/cpu_dispatch_checker.h  (+2, -2)
  91. src/plugin/include/megbrain/plugin/infkern_finder.h  (+7, -5)
  92. src/plugin/include/megbrain/plugin/num_range_checker.h  (+1, -1)
  93. src/plugin/include/megbrain/plugin/opr_footprint.h  (+6, -5)
  94. src/plugin/include/megbrain/plugin/opr_io_dump.h  (+5, -4)
  95. src/plugin/include/megbrain/plugin/profiler.h  (+3, -3)
  96. src/plugin/include/megbrain/plugin/var_value_checker.h  (+1, -1)
  97. src/serialization/include/megbrain/serialization/extern_c_opr.h  (+4, -0)
  98. src/serialization/include/megbrain/serialization/extern_c_opr_io.h  (+13, -10)
  99. src/serialization/include/megbrain/serialization/file.h  (+9, -6)
  100. src/serialization/include/megbrain/serialization/helper.h  (+1, -1)

CMakeLists.txt  (+28, -22)

@@ -82,6 +82,16 @@ option(MGE_WITH_LARGE_ARCHIVE "Enable big archive link support" OFF)
option(MGE_BUILD_WITH_ASAN "Enable build with ASAN, need compiler support" OFF) option(MGE_BUILD_WITH_ASAN "Enable build with ASAN, need compiler support" OFF)
option(MGE_WITH_CUSTOM_OP "Build with Custom op" OFF) option(MGE_WITH_CUSTOM_OP "Build with Custom op" OFF)
if(MSVC OR WIN32) if(MSVC OR WIN32)
# FIXME: statically linking the Windows VC runtime with some Visual Studio versions causes
# runtime issues on some call paths, for example _imperative_rt.pyd --> megengine_shared.dll,
# where a CRT API such as flush cannot find the fd argument; the root cause is unknown.
# As a workaround, link the VC runtime dynamically. In some cases we still statically link
# the VC runtime (MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP/MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP_SP2),
# so please prefer lite_static_all_in_one (lite/CMakeLists.txt) in Windows XP environments.
# If your environment does not have the VC runtime installed, refer to:
# https://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160
option(MGE_STATIC_LINK_WITH_VC_RUNTIME "Enable mge static link with Windows vc runtime" OFF)

option(MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP "Enable deploy inference on Windows xp" OFF) option(MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP "Enable deploy inference on Windows xp" OFF)
# special MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP_SP2 for Windows XP sp2(32bit) # special MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP_SP2 for Windows XP sp2(32bit)
# internal behavior: # internal behavior:
@@ -103,6 +113,9 @@ if(MSVC OR WIN32)
# which always locate in Microsoft Visual Studio/*/*/VC/Tools/MSVC/*/bin/*/*/link.exe # which always locate in Microsoft Visual Studio/*/*/VC/Tools/MSVC/*/bin/*/*/link.exe
set(CMAKE_LINKER "link.exe") set(CMAKE_LINKER "link.exe")
if(MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP OR MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP_SP2) if(MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP OR MGE_DEPLOY_INFERENCE_ON_WINDOWS_XP_SP2)
set(MGE_STATIC_LINK_WITH_VC_RUNTIME ON)
message(STATUS "Force set MGE_STATIC_LINK_WITH_VC_RUNTIME ON when build for Windows XP")

if(NOT ${MGE_ARCH} STREQUAL "i386") if(NOT ${MGE_ARCH} STREQUAL "i386")
message(FATAL_ERROR "only support 32bit when build for Windows xp") message(FATAL_ERROR "only support 32bit when build for Windows xp")
endif() endif()
@@ -273,10 +286,22 @@ if(MSVC OR WIN32)
# for cmake after 3.15.2 # for cmake after 3.15.2
cmake_policy(SET CMP0091 NEW) cmake_policy(SET CMP0091 NEW)
set(CMAKE_OBJECT_PATH_MAX 300) set(CMAKE_OBJECT_PATH_MAX 300)
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebug")
if(MGE_BUILD_WITH_ASAN)
set(MGE_STATIC_LINK_WITH_VC_RUNTIME ON)
message(STATUS "Force set MGE_STATIC_LINK_WITH_VC_RUNTIME ON when build for Windows MGE_BUILD_WITH_ASAN")
endif()
if(MGE_STATIC_LINK_WITH_VC_RUNTIME)
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebug")
else()
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded")
endif()
else() else()
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded")
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebugDLL")
else()
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDLL")
endif()
endif() endif()


add_compile_definitions(NOMINMAX=1 _USE_MATH_DEFINES=1 WIN32=1) add_compile_definitions(NOMINMAX=1 _USE_MATH_DEFINES=1 WIN32=1)
@@ -1183,25 +1208,6 @@ if (NOT MGE_WITH_DISTRIBUTED)
DESTINATION ${MGE_INSTALL_CMAKEDIR}) DESTINATION ${MGE_INSTALL_CMAKEDIR})
endif() endif()


if(MSVC OR WIN32)
add_compile_options(
$<$<CONFIG:>:/MT>
$<$<CONFIG:Debug>:/MTd>
$<$<CONFIG:Release>:/MT>
)
foreach (CompilerFlag
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
if(${CompilerFlag} MATCHES "/MD")
string(REPLACE "/MD" "/MT" ${CompilerFlag} "${${CompilerFlag}}")
set(${CompilerFlag} "${${CompilerFlag}}" CACHE STRING "msvc compiler flags" FORCE)
message(VERBOSE "MSVC flags: ${CompilerFlag}:${${CompilerFlag}}")
endif()
endforeach()
endif()

if(MGE_WITH_JIT_MLIR) if(MGE_WITH_JIT_MLIR)
add_subdirectory(tools/mlir/mgb-opt) add_subdirectory(tools/mlir/mgb-opt)
add_subdirectory(tools/mlir/mgb-file-check) add_subdirectory(tools/mlir/mgb-file-check)
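The FIXME added in the first hunk above describes the failure mode that motivates linking the VC runtime dynamically by default: with a statically linked CRT every module carries a private CRT, so CRT state (heap, FILE table, fd table) is not shared across the _imperative_rt.pyd --> megengine_shared.dll boundary. A hedged, self-contained illustration of that mismatch (module and function names are hypothetical, not MegEngine code):

    #include <cstdio>

    // --- a_dll.cpp, built with /MT: the FILE* comes from a.dll's private CRT ---
    extern "C" __declspec(dllexport) std::FILE* open_log() {
        return std::fopen("log.txt", "w");
    }

    // --- b_dll.cpp, built with its own statically linked CRT ---
    extern "C" __declspec(dllimport) std::FILE* open_log();
    void write_and_flush() {
        std::FILE* f = open_log();   // handle owned by a.dll's CRT
        std::fputs("hello", f);
        std::fflush(f);              // b.dll's CRT has no record of this FILE*/fd -> crash or UB
    }

Linking both modules against the dynamic CRT (MultiThreadedDLL) gives them one shared CRT state, which is why MGE_STATIC_LINK_WITH_VC_RUNTIME defaults to OFF above.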


dnn/cuda-stub/CMakeLists.txt  (+1, -0)

@@ -22,4 +22,5 @@ if (MSVC OR WIN32)
else() else()
target_link_libraries(cuda-stub PRIVATE dl -Wl,--no-undefined) target_link_libraries(cuda-stub PRIVATE dl -Wl,--no-undefined)
endif() endif()
target_include_directories(cuda-stub PRIVATE $<BUILD_INTERFACE:${PROJECT_BINARY_DIR}/genfiles>)
install (TARGETS cuda-stub EXPORT ${MGE_EXPORT_TARGETS}) install (TARGETS cuda-stub EXPORT ${MGE_EXPORT_TARGETS})

dnn/cuda-stub/src/libcuda.cpp  (+24, -2)

@@ -1,9 +1,10 @@
#include "megbrain_build_config.h"

#pragma GCC visibility push(default) #pragma GCC visibility push(default)


#include <cstdio> #include <cstdio>
#define LOGE(fmt, v...) fprintf(stderr, "err: " fmt "\n", ##v) #define LOGE(fmt, v...) fprintf(stderr, "err: " fmt "\n", ##v)



extern "C" { extern "C" {
#include "cuda.h" #include "cuda.h"
} }
@@ -28,8 +29,29 @@ CUresult on_init_failed(int func_idx) {


#if CUDA_VERSION == 10010 #if CUDA_VERSION == 10010
#include "./libcuda-wrap_10.1.h" #include "./libcuda-wrap_10.1.h"

//! some symbols are linked in from the cuda lib but used by other modules, so export them here
#ifdef WIN32
#pragma comment(linker, "/export:cudaSetDevice")
#pragma comment(linker, "/export:cuCtxGetCurrent")
#pragma comment(linker, "/export:cudaGetDeviceCount")
#pragma comment(linker, "/export:cudaGetDeviceProperties")
#pragma comment(linker, "/export:cudaRuntimeGetVersion")
#pragma comment(linker, "/export:cudaGetDevice")
#pragma comment(linker, "/export:cudaDeviceSynchronize")
#endif
#elif CUDA_VERSION == 10020 #elif CUDA_VERSION == 10020
#include "./libcuda-wrap_10.2.h" #include "./libcuda-wrap_10.2.h"
//! some symbols are linked in from the cuda lib but used by other modules, so export them here
#ifdef WIN32
#pragma comment(linker, "/export:cudaSetDevice")
#pragma comment(linker, "/export:cuCtxGetCurrent")
#pragma comment(linker, "/export:cudaGetDeviceCount")
#pragma comment(linker, "/export:cudaGetDeviceProperties")
#pragma comment(linker, "/export:cudaRuntimeGetVersion")
#pragma comment(linker, "/export:cudaGetDevice")
#pragma comment(linker, "/export:cudaDeviceSynchronize")
#endif
#elif CUDA_VERSION == 11010 #elif CUDA_VERSION == 11010
#include "./libcuda-wrap_11.1.h" #include "./libcuda-wrap_11.1.h"
#elif CUDA_VERSION == 11020 #elif CUDA_VERSION == 11020
@@ -79,4 +101,4 @@ static const char* extra_so_paths[] = {
}; };


static const char* g_default_api_name = "cuda"; static const char* g_default_api_name = "cuda";
#include "./dlopen_helper.h"
#include "./dlopen_helper.h"

dnn/include/megdnn/basic_types.h  (+53, -44)

@@ -104,22 +104,22 @@ struct TensorShape {
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
TensorShape() = default; TensorShape() = default;
TensorShape(const TensorShape& rhs) = default; TensorShape(const TensorShape& rhs) = default;
TensorShape(const SmallVector<size_t>& init_shape);
TensorShape(std::initializer_list<size_t> init_shape);
std::string to_string() const;
MGE_WIN_DECLSPEC_FUC TensorShape(const SmallVector<size_t>& init_shape);
MGE_WIN_DECLSPEC_FUC TensorShape(std::initializer_list<size_t> init_shape);
MGE_WIN_DECLSPEC_FUC std::string to_string() const;
#endif #endif


//! total number of elements //! total number of elements
size_t total_nr_elems() const;
MGE_WIN_DECLSPEC_FUC size_t total_nr_elems() const;


//! check whether two shapes are equal //! check whether two shapes are equal
bool eq_shape(const TensorShape& rhs) const;
MGE_WIN_DECLSPEC_FUC bool eq_shape(const TensorShape& rhs) const;


//! check whether the shape can be treated as a scalar //! check whether the shape can be treated as a scalar
bool is_scalar() const { return ndim == 1 && shape[0] == 1; } bool is_scalar() const { return ndim == 1 && shape[0] == 1; }


//! check whether ndim != 0 and at least one shape is 0 //! check whether ndim != 0 and at least one shape is 0
bool is_empty() const;
MGE_WIN_DECLSPEC_FUC bool is_empty() const;


//! access single element, without boundary check //! access single element, without boundary check
size_t& operator[](size_t i) { return shape[i]; } size_t& operator[](size_t i) { return shape[i]; }
@@ -168,8 +168,8 @@ struct TensorLayout : public TensorShape {
class ImplBase; class ImplBase;


#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
Format();
Format(DType dtype);
MGE_WIN_DECLSPEC_FUC Format();
MGE_WIN_DECLSPEC_FUC Format(DType dtype);


const ImplBase* impl() const { return m_impl; } const ImplBase* impl() const { return m_impl; }


@@ -190,16 +190,17 @@ struct TensorLayout : public TensorShape {
} }


//! get human-readable string description of this format //! get human-readable string description of this format
std::string to_string() const;
MGE_WIN_DECLSPEC_FUC std::string to_string() const;


std::string serialize() const;
static Format deserialize(const std::string& bin, const Handle* handle);
MGE_WIN_DECLSPEC_FUC std::string serialize() const;
MGE_WIN_DECLSPEC_FUC static Format deserialize(
const std::string& bin, const Handle* handle);


//! whether this is the default tensor format //! whether this is the default tensor format
bool is_default() const;
MGE_WIN_DECLSPEC_FUC bool is_default() const;


//! whether this is the lowbit aligned to bytes tensor format //! whether this is the lowbit aligned to bytes tensor format
bool is_lowbit_aligned() const;
MGE_WIN_DECLSPEC_FUC bool is_lowbit_aligned() const;


bool operator==(Format rhs) const { return m_impl == rhs.m_impl; } bool operator==(Format rhs) const { return m_impl == rhs.m_impl; }
bool operator!=(Format rhs) const { return m_impl != rhs.m_impl; } bool operator!=(Format rhs) const { return m_impl != rhs.m_impl; }
@@ -218,27 +219,28 @@ struct TensorLayout : public TensorShape {
DType dtype; DType dtype;
Format format; Format format;


TensorLayout();
MGE_WIN_DECLSPEC_FUC TensorLayout();


#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
TensorLayout(const TensorLayout& layout) = default; TensorLayout(const TensorLayout& layout) = default;


//! create empty layout with given dtype //! create empty layout with given dtype
explicit TensorLayout(DType dtype_);
MGE_WIN_DECLSPEC_FUC explicit TensorLayout(DType dtype_);


TensorLayout(DType dtype_, Format format);
MGE_WIN_DECLSPEC_FUC TensorLayout(DType dtype_, Format format);


//! create layout with given shape and contiguous stride. //! create layout with given shape and contiguous stride.
TensorLayout(const TensorShape& shape, DType dtype);
MGE_WIN_DECLSPEC_FUC TensorLayout(const TensorShape& shape, DType dtype);


TensorLayout(const TensorShape& shape, DType dtype, Format format);
MGE_WIN_DECLSPEC_FUC TensorLayout(
const TensorShape& shape, DType dtype, Format format);


//! creating layout with user-specified shape and stride. //! creating layout with user-specified shape and stride.
TensorLayout(
MGE_WIN_DECLSPEC_FUC TensorLayout(
const TensorShape& shape, const std::vector<ptrdiff_t>& stride, const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
DType dtype); DType dtype);


TensorLayout(
MGE_WIN_DECLSPEC_FUC TensorLayout(
const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype, const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype,
Format format); Format format);


@@ -251,28 +253,30 @@ struct TensorLayout : public TensorShape {
* *
* \return total number of elements * \return total number of elements
*/ */
size_t init_contiguous_stride();
MGE_WIN_DECLSPEC_FUC size_t init_contiguous_stride();


/*! /*!
* \brief init stride to be contiguous by first assigning shape * \brief init stride to be contiguous by first assigning shape
* *
* Use current format. * Use current format.
*/ */
size_t init_contiguous_stride(const TensorShape& shape);
MGE_WIN_DECLSPEC_FUC size_t init_contiguous_stride(const TensorShape& shape);


size_t init_contiguous_stride(const TensorShape& shape, Format format);
MGE_WIN_DECLSPEC_FUC size_t
init_contiguous_stride(const TensorShape& shape, Format format);


/*! /*!
* \brief inplace version of remove_axis * \brief inplace version of remove_axis
*/ */
void remove_axis_inplace(size_t idx);
MGE_WIN_DECLSPEC_FUC void remove_axis_inplace(size_t idx);


/*! /*!
* \brief add an axis before given *axis* with given shape and stride * \brief add an axis before given *axis* with given shape and stride
* *
* Other shapes and strides would not be changed. * Other shapes and strides would not be changed.
*/ */
void add_axis_inplace(size_t axis, size_t shape, ptrdiff_t stride);
MGE_WIN_DECLSPEC_FUC void add_axis_inplace(
size_t axis, size_t shape, ptrdiff_t stride);


/*! /*!
* \brief add an axis before given *axis*, with shape 1 and contiguous * \brief add an axis before given *axis*, with shape 1 and contiguous
@@ -287,7 +291,7 @@ struct TensorLayout : public TensorShape {
* *
* By the way this API will modify the format according to the data type * By the way this API will modify the format according to the data type
*/ */
void modify_dtype_inplace(DType dtype);
MGE_WIN_DECLSPEC_FUC void modify_dtype_inplace(DType dtype);


/* =================== generate new layout =================== */ /* =================== generate new layout =================== */


@@ -297,21 +301,23 @@ struct TensorLayout : public TensorShape {
* example: * example:
* (2, 0, 1) -> AxBxC to CxAxB * (2, 0, 1) -> AxBxC to CxAxB
*/ */
TensorLayout dimshuffle(const std::vector<size_t>& dims) const
MEGDNN_WARN_UNUSED_RESULT;
MGE_WIN_DECLSPEC_FUC TensorLayout
dimshuffle(const std::vector<size_t>& dims) const MEGDNN_WARN_UNUSED_RESULT;


/** /**
* \brief Remove an axis from the layout by moving later shape/stride * \brief Remove an axis from the layout by moving later shape/stride
* elements earlier. No extra check is performed. * elements earlier. No extra check is performed.
*/ */
TensorLayout remove_axis(size_t idx) const MEGDNN_WARN_UNUSED_RESULT;
MGE_WIN_DECLSPEC_FUC TensorLayout
remove_axis(size_t idx) const MEGDNN_WARN_UNUSED_RESULT;


/** /**
* \brief Returns a different view. * \brief Returns a different view.
* *
* \throw TensorReshapeError if no stride exists for target shape. * \throw TensorReshapeError if no stride exists for target shape.
*/ */
TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
MGE_WIN_DECLSPEC_FUC TensorLayout
reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;


/*! /*!
* \brief try to reshape to another view; return whether these two shapes * \brief try to reshape to another view; return whether these two shapes
@@ -319,14 +325,16 @@ struct TensorLayout : public TensorShape {
* \return true iff there exists target stride so this layout can be * \return true iff there exists target stride so this layout can be
* converted to target shape and the elements can match. * converted to target shape and the elements can match.
*/ */
bool try_reshape(TensorLayout& output, const TensorShape& shape) const
MEGDNN_WARN_UNUSED_RESULT;
MGE_WIN_DECLSPEC_FUC bool try_reshape(
TensorLayout& output,
const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;


/*! /*!
* \brief Broadcast on dims with shape == 1 to match target *shape*. * \brief Broadcast on dims with shape == 1 to match target *shape*.
* \throw TensorReshapeError if could not be satisfied * \throw TensorReshapeError if could not be satisfied
*/ */
TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
MGE_WIN_DECLSPEC_FUC TensorLayout
broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;


/*! /*!
* \brief Collapse consecutive axes with contiguous layout together * \brief Collapse consecutive axes with contiguous layout together
@@ -335,13 +343,14 @@ struct TensorLayout : public TensorShape {
* scalar, the result would always be a one-dimensional empty or scalar, * scalar, the result would always be a one-dimensional empty or scalar,
* with stride being 1. * with stride being 1.
*/ */
TensorLayout collapse_contiguous() const MEGDNN_WARN_UNUSED_RESULT;
MGE_WIN_DECLSPEC_FUC TensorLayout
collapse_contiguous() const MEGDNN_WARN_UNUSED_RESULT;


/* =================== properties =================== */ /* =================== properties =================== */


std::string to_string() const;
MGE_WIN_DECLSPEC_FUC std::string to_string() const;


std::string serialize() const;
MGE_WIN_DECLSPEC_FUC std::string serialize() const;
#endif // MEGDNN_CC_HOST #endif // MEGDNN_CC_HOST


/*! /*!
@@ -353,17 +362,17 @@ struct TensorLayout : public TensorShape {
* Note that empty tensors (i.e. with 0 shapes) are not considered as * Note that empty tensors (i.e. with 0 shapes) are not considered as
* contiguous. * contiguous.
*/ */
bool is_contiguous() const;
MGE_WIN_DECLSPEC_FUC bool is_contiguous() const;


//! check whether it is physically contiguous disregarding format //! check whether it is physically contiguous disregarding format
bool is_physical_contiguous() const;
MGE_WIN_DECLSPEC_FUC bool is_physical_contiguous() const;


/*! /*!
* \brief check whether the layout is monotonous * \brief check whether the layout is monotonous
* *
* A tensor is monotonous if abs(stride[i]) >= abs(stride[i+1])*shape[i+1] * A tensor is monotonous if abs(stride[i]) >= abs(stride[i+1])*shape[i+1]
*/ */
bool is_abs_monotonous_allow_brdcst() const;
MGE_WIN_DECLSPEC_FUC bool is_abs_monotonous_allow_brdcst() const;


/*! /*!
* \brief check whether the layout is contiguous, allowing broadcasting * \brief check whether the layout is contiguous, allowing broadcasting
@@ -371,7 +380,7 @@ struct TensorLayout : public TensorShape {
* This checks whether the underlying storage is contiguous, where * This checks whether the underlying storage is contiguous, where
* broadcasting is also considered to be so. * broadcasting is also considered to be so.
*/ */
bool is_contiguous_allow_brdcst() const;
MGE_WIN_DECLSPEC_FUC bool is_contiguous_allow_brdcst() const;


/*! /*!
* \brief if this function returns true, then no two elements can occupy the * \brief if this function returns true, then no two elements can occupy the
@@ -382,15 +391,15 @@ struct TensorLayout : public TensorShape {
* still possible that actually no two elements share the same memory * still possible that actually no two elements share the same memory
* location. * location.
*/ */
bool is_non_overlapping_strong() const;
MGE_WIN_DECLSPEC_FUC bool is_non_overlapping_strong() const;


bool eq_layout(const TensorLayout& rhs) const;
MGE_WIN_DECLSPEC_FUC bool eq_layout(const TensorLayout& rhs) const;


//! get lowest and highest offset reachable from this layout //! get lowest and highest offset reachable from this layout
Span span() const;
MGE_WIN_DECLSPEC_FUC Span span() const;


//! total number of access bytes //! total number of access bytes
size_t access_bytes() const;
MGE_WIN_DECLSPEC_FUC size_t access_bytes() const;
}; };


/** /**
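The annotations above put symbols such as TensorShape::to_string into megengine_shared.dll's export table. A minimal consumer sketch (not MegEngine code) of the kind of call that would otherwise fail to link from a separate module such as _imperative_rt.pyd:

    #include <cstdio>
    #include "megdnn/basic_types.h"

    int main() {
        // TensorShape(std::initializer_list<size_t>) and to_string() are among the
        // members annotated with MGE_WIN_DECLSPEC_FUC in the hunk above.
        megdnn::TensorShape shape({2, 3, 4});
        std::printf("%s\n", shape.to_string().c_str());
        return 0;
    }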


dnn/include/megdnn/dtype.h  (+19, -17)

@@ -386,10 +386,11 @@ using DTypeParam = DTypeParamImpl<typename DTypeTrait<DType>::ctype>;
*/ */
class DType { class DType {
private: private:
MEGDNN_NORETURN void on_request_lowbit_size() const;
MGE_WIN_DECLSPEC_FUC MEGDNN_NORETURN void on_request_lowbit_size() const;
// HACK: This is required in ParameterizedDType::downcast_from // HACK: This is required in ParameterizedDType::downcast_from
public: public:
MEGDNN_NORETURN void on_assert_is_failed(const char* rname) const;
MGE_WIN_DECLSPEC_FUC MEGDNN_NORETURN void on_assert_is_failed(
const char* rname) const;


protected: protected:
struct Trait { struct Trait {
@@ -493,7 +494,7 @@ public:
bool operator!=(const DType& rhs) const { return m_trait != rhs.m_trait; } bool operator!=(const DType& rhs) const { return m_trait != rhs.m_trait; }


//! get dtype object from enum //! get dtype object from enum
static DType from_enum(DTypeEnum ev);
MGE_WIN_DECLSPEC_FUC static DType from_enum(DTypeEnum ev);


//! get a handle of the dtype that could be used for equivalence check //! get a handle of the dtype that could be used for equivalence check
const void* handle() const { return m_trait; } const void* handle() const { return m_trait; }
@@ -531,9 +532,10 @@ class ParameterizedDType MEGDNN_FINAL : public DType {
}; };


// static part of the trait // static part of the trait
static DType::Trait sm_trait;
static MGE_WIN_DECLSPEC_DATA DType::Trait sm_trait;


static Trait* make_from_param(const DTypeParam<SelfType>& param);
MGE_WIN_DECLSPEC_FUC static Trait* make_from_param(
const DTypeParam<SelfType>& param);
explicit ParameterizedDType(DType dtype) : DType(dtype) {} explicit ParameterizedDType(DType dtype) : DType(dtype) {}


public: public:
@@ -569,12 +571,12 @@ public:
//! dtype implementation classes //! dtype implementation classes
namespace dtype { namespace dtype {


#define IMPL(_name) \
class _name MEGDNN_FINAL : public DType { \
static Trait sm_trait; \
\
public: \
_name() : DType(&sm_trait) {} \
#define IMPL(_name) \
class _name MEGDNN_FINAL : public DType { \
static MGE_WIN_DECLSPEC_DATA Trait sm_trait; \
\
public: \
_name() : DType(&sm_trait) {} \
}; };


MEGDNN_FOREACH_DTYPE_NAME(IMPL) MEGDNN_FOREACH_DTYPE_NAME(IMPL)
@@ -764,7 +766,7 @@ struct DTypeParamImpl<dt_quint8> {
uint8_t zero_point; uint8_t zero_point;


DTypeParamImpl<dt_quint8>() = default; DTypeParamImpl<dt_quint8>() = default;
DTypeParamImpl<dt_quint8>(float scale, uint8_t zero_point);
MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_quint8>(float scale, uint8_t zero_point);


#ifdef MEGDNN_CC_HOST #ifdef MEGDNN_CC_HOST
std::size_t hash() const; std::size_t hash() const;
@@ -788,7 +790,7 @@ struct DTypeParamImpl<dt_qint8> {
float scale; float scale;


DTypeParamImpl<dt_qint8>() = default; DTypeParamImpl<dt_qint8>() = default;
DTypeParamImpl<dt_qint8>(float scale);
MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_qint8>(float scale);
#ifdef MEGDNN_CC_HOST #ifdef MEGDNN_CC_HOST
std::size_t hash() const; std::size_t hash() const;
#endif #endif
@@ -810,7 +812,7 @@ struct DTypeParamImpl<dt_qint16> {
float scale; float scale;


DTypeParamImpl<dt_qint16>() = default; DTypeParamImpl<dt_qint16>() = default;
DTypeParamImpl<dt_qint16>(float scale);
MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_qint16>(float scale);
#ifdef MEGDNN_CC_HOST #ifdef MEGDNN_CC_HOST
std::size_t hash() const; std::size_t hash() const;
#endif // MEGDNN_CC_HOST #endif // MEGDNN_CC_HOST
@@ -831,7 +833,7 @@ struct DTypeParamImpl<dt_qint32> {
float scale; float scale;


DTypeParamImpl<dt_qint32>() = default; DTypeParamImpl<dt_qint32>() = default;
DTypeParamImpl<dt_qint32>(float scale);
MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_qint32>(float scale);
#ifdef MEGDNN_CC_HOST #ifdef MEGDNN_CC_HOST
std::size_t hash() const; std::size_t hash() const;
#endif // MEGDNN_CC_HOST #endif // MEGDNN_CC_HOST
@@ -854,7 +856,7 @@ struct DTypeParamImpl<dt_quint4> {
uint8_t zero_point; uint8_t zero_point;


DTypeParamImpl<dt_quint4>() = default; DTypeParamImpl<dt_quint4>() = default;
DTypeParamImpl<dt_quint4>(float scale, uint8_t zero_point);
MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_quint4>(float scale, uint8_t zero_point);
#ifdef MEGDNN_CC_HOST #ifdef MEGDNN_CC_HOST
std::size_t hash() const; std::size_t hash() const;
#endif #endif
@@ -879,7 +881,7 @@ struct DTypeParamImpl<dt_qint4> {
float scale; float scale;


DTypeParamImpl<dt_qint4>() = default; DTypeParamImpl<dt_qint4>() = default;
DTypeParamImpl<dt_qint4>(float scale);
MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_qint4>(float scale);
#ifdef MEGDNN_CC_HOST #ifdef MEGDNN_CC_HOST
std::size_t hash() const; std::size_t hash() const;
#endif #endif


dnn/include/megdnn/handle.h  (+13, -12)

@@ -73,20 +73,20 @@ public:
* *
* **Debug level 1 and 2 should not be used in productions.** * **Debug level 1 and 2 should not be used in productions.**
*/ */
static std::unique_ptr<Handle> make(
MGE_WIN_DECLSPEC_FUC static std::unique_ptr<Handle> make(
megcoreComputingHandle_t computing_handle, int debug_level = 0); megcoreComputingHandle_t computing_handle, int debug_level = 0);


#if MEGDNN_WITH_CUDA #if MEGDNN_WITH_CUDA
static std::unique_ptr<Handle> make_cuda_handle(
MGE_WIN_DECLSPEC_FUC static std::unique_ptr<Handle> make_cuda_handle(
megcoreComputingHandle_t computing_handle); megcoreComputingHandle_t computing_handle);
template <typename opr> template <typename opr>
std::unique_ptr<opr> create_cuda_operator();
MGE_WIN_DECLSPEC_FUC std::unique_ptr<opr> create_cuda_operator();
#endif #endif
#if MEGDNN_WITH_ROCM #if MEGDNN_WITH_ROCM
static std::unique_ptr<Handle> make_rocm_handle(
MGE_WIN_DECLSPEC_FUC static std::unique_ptr<Handle> make_rocm_handle(
megcoreComputingHandle_t computing_handle); megcoreComputingHandle_t computing_handle);
template <typename opr> template <typename opr>
std::unique_ptr<opr> create_rocm_operator();
MGE_WIN_DECLSPEC_FUC std::unique_ptr<opr> create_rocm_operator();
#endif #endif


virtual ~Handle(); virtual ~Handle();
@@ -105,7 +105,7 @@ public:
* *
* This function can be called at most once. * This function can be called at most once.
*/ */
void set_destructor(const thin_function<void()>& d);
MGE_WIN_DECLSPEC_FUC void set_destructor(const thin_function<void()>& d);


/*! /*!
* \brief set a callback to be invoked when an operator is destructed * \brief set a callback to be invoked when an operator is destructed
@@ -116,13 +116,13 @@ public:
cb.swap(m_on_opr_destructed); cb.swap(m_on_opr_destructed);
} }


void on_opr_destructed(OperatorBase* opr);
MGE_WIN_DECLSPEC_FUC void on_opr_destructed(OperatorBase* opr);


/** /**
* \brief Create operator of Opr type. * \brief Create operator of Opr type.
*/ */
template <typename Opr> template <typename Opr>
std::unique_ptr<Opr> create_operator();
MGE_WIN_DECLSPEC_FUC std::unique_ptr<Opr> create_operator();


/* /*
* ============================================================= * =============================================================
@@ -134,13 +134,13 @@ public:
* \brief The internal data pointer of TensorND should be aligned to * \brief The internal data pointer of TensorND should be aligned to
* alignment_requirement() in bytes. * alignment_requirement() in bytes.
*/ */
virtual size_t alignment_requirement() const;
MGE_WIN_DECLSPEC_FUC virtual size_t alignment_requirement() const;


//! get alignment in bytes for rows of image 2D tensor format //! get alignment in bytes for rows of image 2D tensor format
virtual size_t image2d_pitch_alignment() const;
MGE_WIN_DECLSPEC_FUC virtual size_t image2d_pitch_alignment() const;


//! get vendor type //! get vendor type
virtual HandleVendorType vendor_type() const;
MGE_WIN_DECLSPEC_FUC virtual HandleVendorType vendor_type() const;


HandleType type() const { return m_handle_type; } HandleType type() const { return m_handle_type; }


@@ -149,7 +149,8 @@ public:
* 1. The handle of the src and the dst is the same kind * 1. The handle of the src and the dst is the same kind
* 2. The dst is continguous. * 2. The dst is continguous.
*/ */
virtual bool check_cross_dev_copy_constraint(const TensorLayout& src);
MGE_WIN_DECLSPEC_FUC virtual bool check_cross_dev_copy_constraint(
const TensorLayout& src);


private: private:
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;


dnn/include/megdnn/oprs/general.h  (+9, -6)

@@ -51,7 +51,7 @@ public:
name(NULL) {} name(NULL) {}


//! get trait from a mode; this function is thread safe //! get trait from a mode; this function is thread safe
static const ModeTrait& from_mode(Mode mode);
MGE_WIN_DECLSPEC_FUC static const ModeTrait& from_mode(Mode mode);
}; };


//! get trait of current mode //! get trait of current mode
@@ -69,17 +69,20 @@ public:
virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0; virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0;


//! deduce output shape (do not check whether arity matches) //! deduce output shape (do not check whether arity matches)
static void deduce_shape(const TensorShapeArray& src, TensorShape& dst);
MGE_WIN_DECLSPEC_FUC static void deduce_shape(
const TensorShapeArray& src, TensorShape& dst);


static void deduce_format(const TensorFormatArray& src, TensorFormat& dst);
MGE_WIN_DECLSPEC_FUC static void deduce_format(
const TensorFormatArray& src, TensorFormat& dst);


//! deduce output layout //! deduce output layout
void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst);
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayoutArray& src, TensorLayout& dst);


protected: protected:
//! throw exception if incorrect layout; broadcast input shape to //! throw exception if incorrect layout; broadcast input shape to
//! output shape //! output shape
void check_layout_and_broadcast(
MGE_WIN_DECLSPEC_FUC void check_layout_and_broadcast(
const TensorLayoutPtrArray& src, const TensorLayout& dst); const TensorLayoutPtrArray& src, const TensorLayout& dst);


private: private:
@@ -577,7 +580,7 @@ public:
ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {} ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {}


//! generate offsets to be used with ParamPackConcat and ParamPackSplit //! generate offsets to be used with ParamPackConcat and ParamPackSplit
static std::vector<dt_int32> gen_offsets(
MGE_WIN_DECLSPEC_FUC static std::vector<dt_int32> gen_offsets(
const TensorShapeArray& shapes, size_t alignment, size_t dtype_size); const TensorShapeArray& shapes, size_t alignment, size_t dtype_size);
}; };




dnn/include/megdnn/oprs/nn_int.h  (+1, -1)

@@ -43,7 +43,7 @@ public:
const char* name = nullptr; //!< name of the mode const char* name = nullptr; //!< name of the mode


//! get trait from a mode; this function is thread safe //! get trait from a mode; this function is thread safe
static const ModeTrait& from_mode(Mode mode);
MGE_WIN_DECLSPEC_FUC static const ModeTrait& from_mode(Mode mode);
}; };


virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0; virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0;


dnn/include/megdnn/thin/small_vector.h  (+4, -2)

@@ -50,7 +50,8 @@ class SmallVectorBase {
protected: protected:
void *m_begin_ptr, *m_end_ptr, *m_capacity_ptr; void *m_begin_ptr, *m_end_ptr, *m_capacity_ptr;


MEGDNN_NORETURN static void on_invalid_at(size_t idx, size_t size);
MGE_WIN_DECLSPEC_FUC MEGDNN_NORETURN static void on_invalid_at(
size_t idx, size_t size);


protected: protected:
SmallVectorBase(void* first_elm, size_t size) SmallVectorBase(void* first_elm, size_t size)
@@ -58,7 +59,8 @@ protected:
m_end_ptr(first_elm), m_end_ptr(first_elm),
m_capacity_ptr(static_cast<char*>(first_elm) + size) {} m_capacity_ptr(static_cast<char*>(first_elm) + size) {}


void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size);
MGE_WIN_DECLSPEC_FUC void grow_pod(
void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size);


public: public:
size_t size_in_bytes() const { size_t size_in_bytes() const {


dnn/include/megdnn/version.h  (+2, -1)

@@ -14,6 +14,7 @@
#define MEGDNN_MINOR 3 #define MEGDNN_MINOR 3
#define MEGDNN_PATCH 0 #define MEGDNN_PATCH 0


#include "megbrain_build_config.h"
#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"


namespace megdnn { namespace megdnn {
@@ -22,7 +23,7 @@ struct Version {
}; };


//! get megdnn version of the binary //! get megdnn version of the binary
Version get_version();
MGE_WIN_DECLSPEC_FUC Version get_version();
} // namespace megdnn } // namespace megdnn


#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"


imperative/CMakeLists.txt  (+4, -4)

@@ -25,15 +25,15 @@ add_custom_target(_version_ld SOURCES ${MGE_VERSION_SCRIPT})
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/pybind11 ${PROJECT_BINARY_DIR}/third_party/pybind11) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/pybind11 ${PROJECT_BINARY_DIR}/third_party/pybind11)
pybind11_add_module(${MODULE_NAME} NO_EXTRAS ${SRCS}) pybind11_add_module(${MODULE_NAME} NO_EXTRAS ${SRCS})
if (APPLE) if (APPLE)
target_link_libraries(${MODULE_NAME} PRIVATE megengine_export)
target_link_libraries(${MODULE_NAME} PRIVATE megengine_shared)
elseif (MSVC OR WIN32) elseif (MSVC OR WIN32)
# Windows does not support implicitly importing data members from DLL.
target_link_libraries(${MODULE_NAME} PRIVATE megbrain megdnn ${MGE_CUDA_LIBS})
target_link_libraries(${MODULE_NAME} PRIVATE megengine_shared)
target_compile_definitions(${MODULE_NAME} PRIVATE MGE_DLL_IMPORT_DATA)
message(STATUS "CMAKE_MSVC_RUNTIME_LIBRARY: ${CMAKE_MSVC_RUNTIME_LIBRARY}") message(STATUS "CMAKE_MSVC_RUNTIME_LIBRARY: ${CMAKE_MSVC_RUNTIME_LIBRARY}")
set_target_properties(${MODULE_NAME} PROPERTIES MSVC_RUNTIME_LIBRARY "${CMAKE_MSVC_RUNTIME_LIBRARY}") set_target_properties(${MODULE_NAME} PROPERTIES MSVC_RUNTIME_LIBRARY "${CMAKE_MSVC_RUNTIME_LIBRARY}")
else() else()
# use to fix runtime crash when build both mgb(MGE_WITH_PYTHON_MODULE) and imperative(MGE_BUILD_IMPERATIVE_RT) # use to fix runtime crash when build both mgb(MGE_WITH_PYTHON_MODULE) and imperative(MGE_BUILD_IMPERATIVE_RT)
target_link_libraries(${MODULE_NAME} PRIVATE megengine_export -Wl,--version-script=${MGE_VERSION_SCRIPT})
target_link_libraries(${MODULE_NAME} PRIVATE megengine_shared -Wl,--version-script=${MGE_VERSION_SCRIPT})
endif() endif()


add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/range-v3 ${PROJECT_BINARY_DIR}/third_party/range-v3) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/range-v3 ${PROJECT_BINARY_DIR}/third_party/range-v3)
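Unlike functions, static data members cannot be resolved implicitly through the import library on Windows, which is why consumers of megengine_shared.dll (the python module here, and the lite targets below) define MGE_DLL_IMPORT_DATA. A sketch under the assumption that the data macro mirrors the function macro; the real definition lives in src/megbrain_build_config.h.in:

    // Sketch only; the define names follow the CMake hunks in this commit.
    #ifdef _WIN32
    #if defined(MGE_DLL_EXPORT_DATA)
    #define MGE_WIN_DECLSPEC_DATA __declspec(dllexport)   // building megengine_shared.dll
    #elif defined(MGE_DLL_IMPORT_DATA)
    #define MGE_WIN_DECLSPEC_DATA __declspec(dllimport)   // _imperative_rt.pyd, lite, load_and_run_depends_shared
    #else
    #define MGE_WIN_DECLSPEC_DATA
    #endif
    #else
    #define MGE_WIN_DECLSPEC_DATA
    #endif

    // e.g. dnn/include/megdnn/dtype.h in this commit:
    //     static MGE_WIN_DECLSPEC_DATA DType::Trait sm_trait;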


imperative/python/megengine/__init__.py  (+4, -0)

@@ -54,6 +54,8 @@ if sys.platform == "win32":
err.strerror += ' Error loading "{}" or one of its dependencies.'.format( err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
dll dll
) )
err.strerror += " \nplease install VC runtime from: "
err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
raise err raise err
elif res is not None: elif res is not None:
is_loaded = True is_loaded = True
@@ -67,6 +69,8 @@ if sys.platform == "win32":
err.strerror += ' Error loading "{}" or one of its dependencies.'.format( err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
dll dll
) )
err.strerror += " \nplease install VC runtime from: "
err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
raise err raise err


kernel32.SetErrorMode(old_error_mode) kernel32.SetErrorMode(old_error_mode)


lite/CMakeLists.txt  (+8, -4)

@@ -42,6 +42,9 @@ include_directories($<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/genfiles>)
if(LITE_BUILD_WITH_MGE) if(LITE_BUILD_WITH_MGE)
target_link_libraries(lite_static PRIVATE megbrain megdnn ${MGE_CUDA_LIBS}) target_link_libraries(lite_static PRIVATE megbrain megdnn ${MGE_CUDA_LIBS})
add_compile_definitions(LITE_BUILD_WITH_MGE=1) add_compile_definitions(LITE_BUILD_WITH_MGE=1)
if(WIN32 OR MSVC)
target_compile_definitions(lite_static PRIVATE MGE_DLL_IMPORT_DATA)
endif()
message(STATUS "build lite with MegEngine.") message(STATUS "build lite with MegEngine.")
else() else()
target_link_libraries(lite_static PUBLIC flatbuffers) target_link_libraries(lite_static PUBLIC flatbuffers)
@@ -71,12 +74,13 @@ endif()
# define a shared lib for whl # define a shared lib for whl
add_library(lite_shared_whl SHARED $<TARGET_OBJECTS:lite_static>) add_library(lite_shared_whl SHARED $<TARGET_OBJECTS:lite_static>)
if(LITE_BUILD_WITH_MGE) if(LITE_BUILD_WITH_MGE)
if (MSVC OR WIN32 OR IOS)
# TODO: this will increase the whl size on Windows, because Windows does not
# support implicitly importing data members from a DLL.
if (IOS)
target_link_libraries(lite_shared_whl PRIVATE megbrain megdnn ${MGE_CUDA_LIBS}) target_link_libraries(lite_shared_whl PRIVATE megbrain megdnn ${MGE_CUDA_LIBS})
else() else()
target_link_libraries(lite_shared_whl PRIVATE megengine_export)
target_link_libraries(lite_shared_whl PRIVATE megengine_shared)
endif()
if(WIN32 OR MSVC)
target_compile_definitions(lite_shared_whl PRIVATE MGE_DLL_IMPORT_DATA)
endif() endif()
endif() endif()
if(ANDROID) if(ANDROID)


lite/pylite/megenginelite/base.py  (+4, -0)

@@ -56,6 +56,8 @@ if sys.platform == "win32":
err.strerror += ' Error loading "{}" or one of its dependencies.'.format( err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
dll dll
) )
err.strerror += " \nplease install VC runtime from: "
err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
raise err raise err
elif res is not None: elif res is not None:
is_loaded = True is_loaded = True
@@ -69,6 +71,8 @@ if sys.platform == "win32":
err.strerror += ' Error loading "{}" or one of its dependencies.'.format( err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
dll dll
) )
err.strerror += " \nplease install VC runtime from: "
err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
raise err raise err


kernel32.SetErrorMode(old_error_mode) kernel32.SetErrorMode(old_error_mode)


scripts/whl/macos/macos_build_whl.sh  (+3, -3)

@@ -89,7 +89,7 @@ function config_python_env() {
fi fi
} }


MEGENGINE_LIB="${SRC_DIR}/build_dir/host/MGE_WITH_CUDA_OFF/MGE_INFERENCE_ONLY_OFF/Release/build/src/libmegengine_export.dylib"
MEGENGINE_LIB="${SRC_DIR}/build_dir/host/MGE_WITH_CUDA_OFF/MGE_INFERENCE_ONLY_OFF/Release/build/src/libmegengine_shared.dylib"
function depend_real_copy() { function depend_real_copy() {
REAL_DST=$1 REAL_DST=$1
echo "real copy lib to $1" echo "real copy lib to $1"
@@ -192,7 +192,7 @@ function do_build() {
fi fi


#handle dlopen path #handle dlopen path
install_name_tool -change @rpath/libmegengine_export.dylib @loader_path/lib/libmegengine_export.dylib _imperative_rt.so
install_name_tool -change @rpath/libmegengine_shared.dylib @loader_path/lib/libmegengine_shared.dylib _imperative_rt.so


#copy megbrain_export lib #copy megbrain_export lib
DEPEND_LIB=${BUILD_DIR}/staging/megengine/core/lib/ DEPEND_LIB=${BUILD_DIR}/staging/megengine/core/lib/
@@ -209,7 +209,7 @@ function do_build() {
cp ${SRC_DIR}/build_dir/host/MGE_WITH_CUDA_OFF/MGE_INFERENCE_ONLY_OFF/Release/build/lite/liblite_shared_whl.dylib ${LITE_LIB} cp ${SRC_DIR}/build_dir/host/MGE_WITH_CUDA_OFF/MGE_INFERENCE_ONLY_OFF/Release/build/lite/liblite_shared_whl.dylib ${LITE_LIB}
llvm-strip -s ${LITE_LIB} llvm-strip -s ${LITE_LIB}
#handle dlopen path #handle dlopen path
install_name_tool -change @rpath/libmegengine_export.dylib @loader_path/../../megengine/core/lib/libmegengine_export.dylib ${LITE_LIB}
install_name_tool -change @rpath/libmegengine_shared.dylib @loader_path/../../megengine/core/lib/libmegengine_shared.dylib ${LITE_LIB}


cd ${BUILD_DIR}/staging cd ${BUILD_DIR}/staging
${PYTHON_DIR}/bin/python3 setup.py bdist_wheel ${PYTHON_DIR}/bin/python3 setup.py bdist_wheel


scripts/whl/manylinux2014/do_build_common.sh  (+4, -4)

@@ -50,10 +50,10 @@ function patch_elf_depend_lib_mgb_mge() {
patchelf --force-rpath --set-rpath '$ORIGIN/lib' ${BUILD_DIR}/staging/megengine/core/_imperative_rt.so patchelf --force-rpath --set-rpath '$ORIGIN/lib' ${BUILD_DIR}/staging/megengine/core/_imperative_rt.so
handle_strip ${BUILD_DIR}/staging/megengine/core/_imperative_rt.so handle_strip ${BUILD_DIR}/staging/megengine/core/_imperative_rt.so


cp ${BUILD_DIR}/src/libmegengine_export.so ${LIBS_DIR}
patchelf --remove-rpath ${LIBS_DIR}/libmegengine_export.so
patchelf --force-rpath --set-rpath '$ORIGIN/.' ${LIBS_DIR}/libmegengine_export.so
handle_strip ${LIBS_DIR}/libmegengine_export.so
cp ${BUILD_DIR}/src/libmegengine_shared.so ${LIBS_DIR}
patchelf --remove-rpath ${LIBS_DIR}/libmegengine_shared.so
patchelf --force-rpath --set-rpath '$ORIGIN/.' ${LIBS_DIR}/libmegengine_shared.so
handle_strip ${LIBS_DIR}/libmegengine_shared.so


# as some version of cudnn/trt libs have dlopen libs, so we can not use auditwheel # as some version of cudnn/trt libs have dlopen libs, so we can not use auditwheel
# TODO: PR for auditwheel to support args for dlopen libs # TODO: PR for auditwheel to support args for dlopen libs


scripts/whl/windows/windows_build_whl.sh  (+20, -25)

@@ -18,8 +18,6 @@ function append_path_env_and_check() {
export VS_PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2019/Enterprise export VS_PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2019/Enterprise
echo "export LLVM install path" echo "export LLVM install path"
export LLVM_PATH=/c/Program\ Files/LLVM_12_0_1 export LLVM_PATH=/c/Program\ Files/LLVM_12_0_1
# for llvm-strip
export PATH=${LLVM_PATH}/bin/:$PATH
} }


append_path_env_and_check append_path_env_and_check
@@ -78,16 +76,23 @@ CUBLAS_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cublas6
CURAND_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/curand64_10.dll" CURAND_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/curand64_10.dll"
CUBLASLT_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cublasLt64_10.dll" CUBLASLT_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cublasLt64_10.dll"
CUDART_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cudart64_101.dll" CUDART_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cudart64_101.dll"
MGE_EXPORT_LIB="${SRC_DIR}/build_dir/host/build/src/megengine_shared.dll"

function depend_real_copy() { function depend_real_copy() {
REAL_DST=$1 REAL_DST=$1
echo "real copy lib to $1" echo "real copy lib to $1"
cp "${TRT_LIB}" ${REAL_DST}
cp "${CUDNN_LIB}" ${REAL_DST}
cp "${CUSOLVER_LIB}" ${REAL_DST}
cp "${CUBLAS_LIB}" ${REAL_DST}
cp "${CURAND_LIB}" ${REAL_DST}
cp "${CUBLASLT_LIB}" ${REAL_DST}
cp "${CUDART_LIB}" ${REAL_DST}
cp "${MGE_EXPORT_LIB}" ${REAL_DST}

if [ ${BUILD_WHL_CPU_ONLY} = "OFF" ]; then
echo "copy nvidia lib...."
cp "${TRT_LIB}" ${REAL_DST}
cp "${CUDNN_LIB}" ${REAL_DST}
cp "${CUSOLVER_LIB}" ${REAL_DST}
cp "${CUBLAS_LIB}" ${REAL_DST}
cp "${CURAND_LIB}" ${REAL_DST}
cp "${CUBLASLT_LIB}" ${REAL_DST}
cp "${CUDART_LIB}" ${REAL_DST}
fi
} }


function copy_more_dll() { function copy_more_dll() {
@@ -97,23 +102,15 @@ function copy_more_dll() {
rm -rf ${CP_WHL_DST_IMP} rm -rf ${CP_WHL_DST_IMP}
mkdir ${CP_WHL_DST_IMP} mkdir ${CP_WHL_DST_IMP}


# workround for cpu-only version import failed, use a
# empty.file to triger setup.py to create a null empty
echo "empty" > ${CP_WHL_DST_IMP}/empty.file

if [ ${BUILD_WHL_CPU_ONLY} = "OFF" ]; then
echo "copy nvidia lib to whl use...."
depend_real_copy ${CP_WHL_DST_IMP}
fi
depend_real_copy ${CP_WHL_DST_IMP}
} }


function lite_copy_more_dll() { function lite_copy_more_dll() {
if [ ${BUILD_WHL_CPU_ONLY} = "OFF" ]; then
if [ ${IN_CI} = "true" ]; then
echo "copy lib for lite for ci test"
IMP_TEST_DST=${SRC_DIR}/build_dir/host/build/lite/test/
depend_real_copy ${IMP_TEST_DST}
fi
if [ ${IN_CI} = "true" ]; then
echo "copy lib for lite for ci test"
IMP_TEST_DST=${SRC_DIR}/build_dir/host/build/lite/test/
depend_real_copy ${IMP_TEST_DST}
rm "${IMP_TEST_DST}/megengine_shared.dll"
fi fi
} }


@@ -199,7 +196,6 @@ function do_build() {
echo "ERR: can not find valid rt file" echo "ERR: can not find valid rt file"
exit -1 exit -1
fi fi
llvm-strip -s ${rt_file}
mv ${rt_file} _imperative_rt.pyd mv ${rt_file} _imperative_rt.pyd


copy_more_dll copy_more_dll
@@ -212,7 +208,6 @@ function do_build() {
mkdir -p ${LITE_CORE_LIB_DIR} mkdir -p ${LITE_CORE_LIB_DIR}
cd ${LITE_CORE_LIB_DIR} cd ${LITE_CORE_LIB_DIR}
cp ${BUILD_DIR}/lite/lite_shared_whl.dll liblite_shared_whl.pyd cp ${BUILD_DIR}/lite/lite_shared_whl.dll liblite_shared_whl.pyd
llvm-strip -s liblite_shared_whl.pyd
lite_copy_more_dll lite_copy_more_dll


cd ${BUILD_DIR}/staging cd ${BUILD_DIR}/staging


sdk/load-and-run/CMakeLists.txt  (+17, -14)

@@ -1,21 +1,24 @@
include_directories(src) include_directories(src)
file (GLOB_RECURSE SOURCES src/*.cpp main.cpp)
add_executable (load_and_run ${SOURCES})
file(GLOB_RECURSE SOURCES src/*.cpp main.cpp)


if (WIN32)
# Windows does not support implicitly importing data members from DLL.
target_link_libraries(load_and_run megbrain megdnn ${MGE_CUDA_LIBS})
else()
target_link_libraries (load_and_run megengine)
add_executable(load_and_run ${SOURCES})
target_link_libraries(load_and_run megbrain megdnn ${MGE_CUDA_LIBS})

# load_and_run_depends_shared is always built for the CI check, please do not delete it
if(BUILD_SHARED_LIBS)
add_executable(load_and_run_depends_shared ${SOURCES})
target_link_libraries(load_and_run_depends_shared megengine)
if(WIN32 OR MSVC)
target_compile_definitions(load_and_run_depends_shared PRIVATE MGE_DLL_IMPORT_DATA)
endif()
endif()

install(TARGETS load_and_run EXPORT ${MGE_EXPORT_TARGETS} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
if(BUILD_SHARED_LIBS)
install(TARGETS load_and_run_depends_shared EXPORT ${MGE_EXPORT_TARGETS} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
endif() endif()
install (TARGETS load_and_run EXPORT ${MGE_EXPORT_TARGETS} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})


if(MGE_WITH_TEST) if(MGE_WITH_TEST)
add_executable(json_loader_test test/json_loader_test.cpp src/json_loader.h src/json_loader.cpp) add_executable(json_loader_test test/json_loader_test.cpp src/json_loader.h src/json_loader.cpp)
# Windows does not support implicitly importing data members from DLL.
if (WIN32)
target_link_libraries (json_loader_test megbrain megdnn ${MGE_CUDA_LIBS})
else()
target_link_libraries (json_loader_test megengine)
endif()
target_link_libraries(json_loader_test megbrain megdnn ${MGE_CUDA_LIBS})
endif() endif()

src/CMakeLists.txt  (+39, -29)

@@ -204,35 +204,45 @@ endif()


set (_VER_FILE ${PROJECT_SOURCE_DIR}/src/version.ld) set (_VER_FILE ${PROJECT_SOURCE_DIR}/src/version.ld)


# Windows does not support implicitly importing data members from DLL.
# on Windows:
# depends on megdnn/megbrain target, refs to sdk/load-and-run/CMakeLists.txt
# depends on megengine lite_share or lite_static
if(NOT WIN32)
message(VERBOSE "create a export SHARED lib for python use")
add_library(megengine_export SHARED)
target_link_libraries(megengine_export PUBLIC megbrain megdnn)
target_link_libraries(megengine_export PRIVATE ${MGE_CUDA_LIBS})
if (MGE_WITH_DISTRIBUTED)
message(VERBOSE "megengine_export configured to link megray")
target_link_libraries(megengine_export PUBLIC megray)
endif()

# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
add_library(megengine)
target_link_libraries(megengine PRIVATE ${MGE_CUDA_LIBS})
target_link_libraries(megengine PUBLIC megbrain megdnn)
if (UNIX AND NOT APPLE)
target_link_options(megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=${_VER_FILE})
set_target_properties(megengine PROPERTIES LINK_DEPENDS ${_VER_FILE})
endif()
# Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready
# for this.
install(TARGETS megengine
EXPORT ${MGE_EXPORT_TARGETS}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()
# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
add_library(megengine)
# force define a SHARED target for the whl: when building for APPLE we force
# BUILD_SHARED_LIBS=OFF because xcode needs it
add_library(megengine_shared SHARED)
target_link_libraries(megengine PRIVATE ${MGE_CUDA_LIBS})
target_link_libraries(megengine PUBLIC megbrain megdnn)
target_link_libraries(megengine_shared PUBLIC megbrain megdnn)
target_link_libraries(megengine_shared PRIVATE ${MGE_CUDA_LIBS})
if (UNIX AND NOT APPLE)
target_link_options(megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=${_VER_FILE})
set_target_properties(megengine PROPERTIES LINK_DEPENDS ${_VER_FILE})
target_link_options(megengine_shared PRIVATE -Wl,--no-undefined -Wl,--version-script=${_VER_FILE})
set_target_properties(megengine_shared PROPERTIES LINK_DEPENDS ${_VER_FILE})
endif()
if(WIN32 OR MSVC)
target_compile_definitions(megbrain PRIVATE MGE_DLL_EXPORT)
target_compile_definitions(megdnn PRIVATE MGE_DLL_EXPORT)
target_compile_definitions(megengine PRIVATE MGE_DLL_EXPORT)
target_compile_definitions(megengine_shared PRIVATE MGE_DLL_EXPORT)
target_compile_definitions(megbrain PRIVATE MGE_DLL_EXPORT_DATA)
target_compile_definitions(megdnn PRIVATE MGE_DLL_EXPORT_DATA)
target_compile_definitions(megengine PRIVATE MGE_DLL_EXPORT_DATA)
target_compile_definitions(megengine_shared PRIVATE MGE_DLL_EXPORT_DATA)
# please do not use WINDOWS_EXPORT_ALL_SYMBOLS, as the number of symbols exceeds 65535 when building with CUDA
#set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
#set_target_properties(megengine_shared PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()
if (MGE_WITH_DISTRIBUTED)
message(VERBOSE "megengine configured to link megray")
target_link_libraries(megengine PUBLIC megray)
target_link_libraries(megengine_shared PUBLIC megray)
endif()
# Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready
# for this.
install(TARGETS megengine
EXPORT ${MGE_EXPORT_TARGETS}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})


if (NOT MGE_WITH_DISTRIBUTED) if (NOT MGE_WITH_DISTRIBUTED)
install(TARGETS megbrain install(TARGETS megbrain


+ 13
- 0
src/core/impl/comp_node/cuda/comp_node.cpp

@@ -612,6 +612,19 @@ bool CudaCompNodeImpl::check_global_finalized() {
"recovery by OS!!"); "recovery by OS!!");
return true; return true;
} }
//! FIXME: when megengine is linked dynamically against the VC runtime, the
//! atexit function table has an ordering issue that causes the cuda runtime
//! to be unloaded too early; this always happens when python3 unloads the
//! dll (i.e. python3 is exiting). As a workaround, let the OS reclaim the
//! resources for now; this may be removed after the cuda runtime is upgraded.
int dev = -1;
if (cudaErrorCudartUnloading == cudaGetDevice(&dev)) {
mgb_log_debug(
"windows cudaErrorCudartUnloading happened!!, resource "
"recovery by OS!!");
return true;
}

#endif
return false;
}


+ 6
- 6
src/core/impl/graph/static_infer_impl.h

@@ -58,18 +58,18 @@ public:
/*! /*!
* \brief get a tag handler for shape inference * \brief get a tag handler for shape inference
*/ */
TagHandler* get_tag_handler_for_shape(Tag tag);
MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_shape(Tag tag);


/*! /*!
* \brief get a tag handler for value inference * \brief get a tag handler for value inference
*/ */
TagHandler* get_tag_handler_for_value(Tag tag);
MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_value(Tag tag);


/*! /*!
* \brief clear registered handler for a tag; this is only used in error * \brief clear registered handler for a tag; this is only used in error
* handling in opr creation * handling in opr creation
*/ */
void clear_tag_handler(Tag tag);
MGE_WIN_DECLSPEC_FUC void clear_tag_handler(Tag tag);


/*!
* \brief set the operator that is allowed to call register_*_infer
@@ -87,13 +87,13 @@ public:
* tag * tag
* \return set of missing inputs; the pointer is always available * \return set of missing inputs; the pointer is always available
*/ */
const TagHandlerSet& get_missing_inp(TagHandler* dest);
MGE_WIN_DECLSPEC_FUC const TagHandlerSet& get_missing_inp(TagHandler* dest);


/*! /*!
* \brief update mutable src tag's shape explicitly, which is only used by
* eager eval
*/ */
void update_mutable_src_shape(Tag tag);
MGE_WIN_DECLSPEC_FUC void update_mutable_src_shape(Tag tag);


/*! /*!
* \brief get original deps given in the InferDesc which is registered * \brief get original deps given in the InferDesc which is registered
@@ -103,7 +103,7 @@ public:
* deps since the StaticInferManagerImpl folds the inference chain of
* the const var shape
*/ */
DepVal get_deps(const DepElement& elem);
MGE_WIN_DECLSPEC_FUC DepVal get_deps(const DepElement& elem);


private: private:
friend class CompSeqManager; friend class CompSeqManager;


+ 4
- 4
src/core/impl/tensor.cpp

@@ -333,7 +333,7 @@ namespace mgb {
// host to host // host to host
template <> template <>
template <> template <>
void TensorStorage<HostTensorStorageTrait>::copy_from(
MGE_WIN_DECLSPEC_FUC void TensorStorage<HostTensorStorageTrait>::copy_from(
const TensorStorage<HostTensorStorageTrait>& src, size_t size) const { const TensorStorage<HostTensorStorageTrait>& src, size_t size) const {
mgb_assert(size <= this->size() && size <= src.size()); mgb_assert(size <= this->size() && size <= src.size());
memcpy(ptr(), src.ptr(), size); memcpy(ptr(), src.ptr(), size);
@@ -342,7 +342,7 @@ void TensorStorage<HostTensorStorageTrait>::copy_from(
// device to host // device to host
template <> template <>
template <> template <>
void TensorStorage<HostTensorStorageTrait>::copy_from(
MGE_WIN_DECLSPEC_FUC void TensorStorage<HostTensorStorageTrait>::copy_from(
const TensorStorage<DeviceTensorStorageTrait>& src, size_t size) const { const TensorStorage<DeviceTensorStorageTrait>& src, size_t size) const {
bool need_sync = false; bool need_sync = false;
mgb_assert(size <= this->size() && size <= src.size()); mgb_assert(size <= this->size() && size <= src.size());
@@ -370,7 +370,7 @@ void TensorStorage<HostTensorStorageTrait>::copy_from(
// host to device // host to device
template <> template <>
template <> template <>
void TensorStorage<DeviceTensorStorageTrait>::copy_from(
MGE_WIN_DECLSPEC_FUC void TensorStorage<DeviceTensorStorageTrait>::copy_from(
const TensorStorage<HostTensorStorageTrait>& src, size_t size) const { const TensorStorage<HostTensorStorageTrait>& src, size_t size) const {
mgb_assert(size <= this->size() && size <= src.size()); mgb_assert(size <= this->size() && size <= src.size());
m_comp_node.copy_to_device(ptr(), src.ptr(), size); m_comp_node.copy_to_device(ptr(), src.ptr(), size);
@@ -379,7 +379,7 @@ void TensorStorage<DeviceTensorStorageTrait>::copy_from(
// device to device // device to device
template <> template <>
template <> template <>
void TensorStorage<DeviceTensorStorageTrait>::copy_from(
MGE_WIN_DECLSPEC_FUC void TensorStorage<DeviceTensorStorageTrait>::copy_from(
const TensorStorage<DeviceTensorStorageTrait>& src, size_t size) const { const TensorStorage<DeviceTensorStorageTrait>& src, size_t size) const {
mgb_assert(size <= this->size() && size <= src.size()); mgb_assert(size <= this->size() && size <= src.size());
if (src.comp_node().device_type() == CompNode::DeviceType::CPU && if (src.comp_node().device_type() == CompNode::DeviceType::CPU &&


+ 8
- 7
src/core/include/megbrain/common.h

@@ -110,7 +110,7 @@ void __on_exception_throw__(const std::exception& exc) __attribute__((noreturn))
} while (0) } while (0)


// assert // assert
void __assert_fail__(
MGE_WIN_DECLSPEC_FUC void __assert_fail__(
const char* file, int line, const char* func, const char* expr, const char* file, int line, const char* func, const char* expr,
const char* msg_fmt = 0, ...) __attribute__((format(printf, 5, 6), noreturn)); const char* msg_fmt = 0, ...) __attribute__((format(printf, 5, 6), noreturn));
#if MGB_ASSERT_LOC #if MGB_ASSERT_LOC
@@ -165,23 +165,23 @@ typedef void (*LogHandler)(
* *
* \return previous log level * \return previous log level
*/ */
LogLevel set_log_level(LogLevel level);
MGE_WIN_DECLSPEC_FUC LogLevel set_log_level(LogLevel level);


/*! /*!
* \brief get logging level * \brief get logging level
* *
* \return current log level * \return current log level
*/ */
LogLevel get_log_level();
MGE_WIN_DECLSPEC_FUC LogLevel get_log_level();


/*! /*!
* \brief set callback for receiving log requests * \brief set callback for receiving log requests
* \return previous log handler * \return previous log handler
*/ */
LogHandler set_log_handler(LogHandler handler);
MGE_WIN_DECLSPEC_FUC LogHandler set_log_handler(LogHandler handler);


#if MGB_ENABLE_LOGGING #if MGB_ENABLE_LOGGING
void __log__(
MGE_WIN_DECLSPEC_FUC void __log__(
LogLevel level, const char* file, const char* func, int line, const char* fmt, LogLevel level, const char* file, const char* func, int line, const char* fmt,
...) __attribute__((format(printf, 5, 6))); ...) __attribute__((format(printf, 5, 6)));


@@ -233,9 +233,10 @@ void __log__(
/*! /*!
* \brief printf-like std::string constructor * \brief printf-like std::string constructor
*/ */
std::string ssprintf(const char* fmt, ...) __attribute__((format(printf, 1, 2)));
MGE_WIN_DECLSPEC_FUC std::string ssprintf(const char* fmt, ...)
__attribute__((format(printf, 1, 2)));


std::string svsprintf(const char* fmt, va_list ap);
MGE_WIN_DECLSPEC_FUC std::string svsprintf(const char* fmt, va_list ap);


#if 0
// used for win32 with vs prior to 2015
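With these declarations exported, an application linking against megengine.dll can tune MegBrain logging directly. A small hedged usage sketch (the LogLevel enumerator name is an assumption; the set_log_level signature comes from the declaration above):

    #include "megbrain/common.h"

    void quiet_mgb_logs() {
        // raise the threshold so only errors are reported; the previous level
        // is returned so a caller can restore it later
        auto prev = mgb::set_log_level(mgb::LogLevel::ERROR);
        mgb::set_log_level(prev);  // restore immediately in this tiny demo
    }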


+ 37
- 34
src/core/include/megbrain/comp_node.h

@@ -129,18 +129,19 @@ public:
* currently supported ID format: (gpu|cpu)<n>[:m] where n is the * currently supported ID format: (gpu|cpu)<n>[:m] where n is the
* device number, possibly with m as the stream id. * device number, possibly with m as the stream id.
*/ */
static Locator parse(const std::string& id);
MGE_WIN_DECLSPEC_FUC static Locator parse(const std::string& id);


/*! /*!
* \brief set mapping between device numbers of a device type * \brief set mapping between device numbers of a device type
*/ */
static void set_device_map(DeviceType type, int from, int to);
MGE_WIN_DECLSPEC_FUC static void set_device_map(
DeviceType type, int from, int to);


/*! /*!
* \brief set the actual device type to be used for * \brief set the actual device type to be used for
* DeviceType::UNSPEC * DeviceType::UNSPEC
*/ */
static void set_unspec_device_type(DeviceType type);
MGE_WIN_DECLSPEC_FUC static void set_unspec_device_type(DeviceType type);


/*! /*!
* \brief get corresponding physical Locator * \brief get corresponding physical Locator
@@ -148,13 +149,13 @@ public:
* DeviceType::UNSPEC would be resolved, and device map would be * DeviceType::UNSPEC would be resolved, and device map would be
* applied on device number * applied on device number
*/ */
Locator to_physical() const;
MGE_WIN_DECLSPEC_FUC Locator to_physical() const;


/*! /*!
* \brief get string description of this locator that can be parsed * \brief get string description of this locator that can be parsed
* again * again
*/ */
std::string to_string() const;
MGE_WIN_DECLSPEC_FUC std::string to_string() const;


bool operator==(const Locator& rhs) const { bool operator==(const Locator& rhs) const {
return type == rhs.type && device == rhs.device && stream == rhs.stream; return type == rhs.type && device == rhs.device && stream == rhs.stream;
@@ -186,7 +187,7 @@ public:
/*! /*!
* \brief manually destroy all comp node resources * \brief manually destroy all comp node resources
*/ */
static void finalize();
MGE_WIN_DECLSPEC_FUC static void finalize();


/*! /*!
* \brief load a computing node from logical locator ID; * \brief load a computing node from logical locator ID;
@@ -201,7 +202,7 @@ public:
return load(locator.to_physical(), locator); return load(locator.to_physical(), locator);
} }


static CompNode load(
MGE_WIN_DECLSPEC_FUC static CompNode load(
const Locator& locator_physical, const Locator& locator_logical); const Locator& locator_physical, const Locator& locator_logical);


/* =================== memory management ======================== */ /* =================== memory management ======================== */
@@ -216,10 +217,10 @@ public:
* *
* Exception should be raised if allocation fails. * Exception should be raised if allocation fails.
*/ */
void* alloc_device(size_t size) const;
MGE_WIN_DECLSPEC_FUC void* alloc_device(size_t size) const;


//! deallocate device buffer; see alloc_device() for more details //! deallocate device buffer; see alloc_device() for more details
void free_device(void* ptr) const;
MGE_WIN_DECLSPEC_FUC void free_device(void* ptr) const;


/*! /*!
* \brief allocate memory on host that is associated with the device, * \brief allocate memory on host that is associated with the device,
@@ -227,9 +228,9 @@ public:
* *
* Both allocation and deallocation on host are synchronous. * Both allocation and deallocation on host are synchronous.
*/ */
void* alloc_host(size_t size) const;
MGE_WIN_DECLSPEC_FUC void* alloc_host(size_t size) const;


void free_host(void* ptr) const;
MGE_WIN_DECLSPEC_FUC void free_host(void* ptr) const;


//! copy from underlying device to host //! copy from underlying device to host
void copy_to_host(void* host_ptr, const void* device_ptr, size_t size) const { void copy_to_host(void* host_ptr, const void* device_ptr, size_t size) const {
@@ -269,19 +270,20 @@ public:
* \brief release consecutive free chunks on all devices to defragment; * \brief release consecutive free chunks on all devices to defragment;
* see DevMemAlloc::try_coalesce_free * see DevMemAlloc::try_coalesce_free
*/ */
static void try_coalesce_all_free_memory();
MGE_WIN_DECLSPEC_FUC static void try_coalesce_all_free_memory();


/* /*
* \brief specifies how to pre-allocate from raw dev allocator * \brief specifies how to pre-allocate from raw dev allocator
* *
*/ */
static void set_prealloc_config(
MGE_WIN_DECLSPEC_FUC static void set_prealloc_config(
size_t alignment, size_t min_req, size_t max_overhead, double growth_factor, size_t alignment, size_t min_req, size_t max_overhead, double growth_factor,
DeviceType device_type); DeviceType device_type);
/*! /*!
* \brief get compute capability of the specified device * \brief get compute capability of the specified device
*/ */
static size_t get_compute_capability(int dev, DeviceType device_type);
MGE_WIN_DECLSPEC_FUC static size_t get_compute_capability(
int dev, DeviceType device_type);


/* =================== synchronization ======================== */ /* =================== synchronization ======================== */


@@ -304,7 +306,7 @@ public:
/*! /*!
* \brief synchronize all computing nodes * \brief synchronize all computing nodes
*/ */
static void sync_all();
MGE_WIN_DECLSPEC_FUC static void sync_all();


/* =================== misc ======================== */ /* =================== misc ======================== */


@@ -341,7 +343,7 @@ public:
#endif #endif


//! change to another stream on the same memory node //! change to another stream on the same memory node
CompNode change_stream(int dest_stream) const;
MGE_WIN_DECLSPEC_FUC CompNode change_stream(int dest_stream) const;


//! get string representation //! get string representation
std::string to_string() const { std::string to_string() const {
@@ -371,10 +373,10 @@ public:
Locator locator_logical() const { return m_impl->locator_logical(); } Locator locator_logical() const { return m_impl->locator_logical(); }


//! see CompNodeEnv::activate //! see CompNodeEnv::activate
void activate() const;
MGE_WIN_DECLSPEC_FUC void activate() const;


//! get device type of this comp node //! get device type of this comp node
DeviceType device_type() const;
MGE_WIN_DECLSPEC_FUC DeviceType device_type() const;


/*! /*!
* \brief check for error on the asynchronous computing stream * \brief check for error on the asynchronous computing stream
@@ -385,7 +387,7 @@ public:
* directly throw exception; return nullptr if no error. * directly throw exception; return nullptr if no error.
*/ */
MGB_WARN_UNUSED_RESULT MGB_WARN_UNUSED_RESULT
std::unique_ptr<MegBrainError> check_async_error() const;
MGE_WIN_DECLSPEC_FUC std::unique_ptr<MegBrainError> check_async_error() const;


/*! /*!
* \brief create a CompNodeSeqRecorder associated with this computing * \brief create a CompNodeSeqRecorder associated with this computing
@@ -461,7 +463,7 @@ public:


bool contain_flag(Flag flag) { return contain_flag(device_type(), flag); } bool contain_flag(Flag flag) { return contain_flag(device_type(), flag); }


static bool contain_flag(DeviceType device_type, Flag flag);
MGE_WIN_DECLSPEC_FUC static bool contain_flag(DeviceType device_type, Flag flag);


using UnorderedSet = ThinHashSet<CompNode>; using UnorderedSet = ThinHashSet<CompNode>;


@@ -469,16 +471,17 @@ public:
using UnorderedMap = ThinHashMap<CompNode, T>; using UnorderedMap = ThinHashMap<CompNode, T>;


//! apply function to each initialized comp node //! apply function to each initialized comp node
static void foreach (thin_function<void(CompNode)> callback);
MGE_WIN_DECLSPEC_FUC static void foreach (thin_function<void(CompNode)> callback);


//! get total number of specific devices on this system //! get total number of specific devices on this system
static size_t get_device_count(DeviceType type, bool warn = true);
MGE_WIN_DECLSPEC_FUC static size_t get_device_count(
DeviceType type, bool warn = true);


/* =================== specialized ======================== */ /* =================== specialized ======================== */


//! get default CPU comp node //! get default CPU comp node
// implemented in comp_node/cpu/comp_node.cpp // implemented in comp_node/cpu/comp_node.cpp
static CompNode default_cpu();
MGE_WIN_DECLSPEC_FUC static CompNode default_cpu();


/*! /*!
* \brief set whether to enable affinity setting for CPU comp nodes * \brief set whether to enable affinity setting for CPU comp nodes
@@ -491,7 +494,7 @@ public:
* *
* \return original setting * \return original setting
*/ */
static bool enable_affinity_for_cpu(bool flag);
MGE_WIN_DECLSPEC_FUC static bool enable_affinity_for_cpu(bool flag);


protected: protected:
//! ImplBase with env(); defined in CompNodeEnv //! ImplBase with env(); defined in CompNodeEnv
@@ -680,15 +683,15 @@ class CompNode::EventPool {
size_t m_flags; size_t m_flags;


public: public:
explicit EventPool(CompNode cn, size_t flags = 0);
~EventPool();
MGE_WIN_DECLSPEC_FUC explicit EventPool(CompNode cn, size_t flags = 0);
MGE_WIN_DECLSPEC_FUC ~EventPool();


CompNode::Event* alloc();
MGE_WIN_DECLSPEC_FUC CompNode::Event* alloc();


void free(CompNode::Event* ev);
MGE_WIN_DECLSPEC_FUC void free(CompNode::Event* ev);


//! assert that all allocated events have been freed //! assert that all allocated events have been freed
void assert_all_freed();
MGE_WIN_DECLSPEC_FUC void assert_all_freed();
}; };


void CompNode::device_wait_event(Event& event) const { void CompNode::device_wait_event(Event& event) const {
@@ -732,14 +735,14 @@ class DepedentObjList {
} }


protected: protected:
virtual std::shared_ptr<void> callback() = 0;
MGE_WIN_DECLSPEC_FUC virtual std::shared_ptr<void> callback() = 0;
~DepedentObjList() = default; ~DepedentObjList() = default;


static void add(DepedentObjList* ptr);
static void remove(DepedentObjList* ptr);
MGE_WIN_DECLSPEC_FUC static void add(DepedentObjList* ptr);
MGE_WIN_DECLSPEC_FUC static void remove(DepedentObjList* ptr);


public: public:
static void invoke_callback_and_clean();
MGE_WIN_DECLSPEC_FUC static void invoke_callback_and_clean();
}; };


} // namespace comp_node_detail } // namespace comp_node_detail
@@ -764,7 +767,7 @@ public:
class CompNodeDepedentObject : private comp_node_detail::DepedentObjList { class CompNodeDepedentObject : private comp_node_detail::DepedentObjList {
//! 1: in on_comp_node_finalize(); 2: after on_comp_node_finalize() //! 1: in on_comp_node_finalize(); 2: after on_comp_node_finalize()
int m_state = 0; int m_state = 0;
std::shared_ptr<void> callback() override final;
MGE_WIN_DECLSPEC_FUC std::shared_ptr<void> callback() override final;


protected:
CompNodeDepedentObject() { add(this); }
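Because these CompNode entry points are now part of the exported surface of megengine.dll, a Windows consumer can drive computing nodes through the same public API as on Linux. A minimal hedged sketch built only from declarations visible in this hunk (the include path is assumed from the file location):

    #include "megbrain/comp_node.h"

    int main() {
        using mgb::CompNode;
        // ID format documented above: (gpu|cpu)<n>[:m]; "cpu0" keeps the
        // sketch working without a GPU
        auto locator = CompNode::Locator::parse("cpu0");
        CompNode cn = CompNode::load(locator);
        void* buf = cn.alloc_device(1024);  // exported via MGE_WIN_DECLSPEC_FUC
        cn.free_device(buf);
        CompNode::sync_all();
        return 0;
    }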


+ 12
- 10
src/core/include/megbrain/comp_node_env.h

@@ -191,10 +191,10 @@ namespace mgb {
#endif #endif


#if MGB_CUDA #if MGB_CUDA
[[noreturn]] void _on_cuda_error(
[[noreturn]] MGE_WIN_DECLSPEC_FUC void _on_cuda_error(
const char* expr, cudaError_t err, const char* file, const char* func, const char* expr, cudaError_t err, const char* file, const char* func,
int line); int line);
[[noreturn]] void _on_cuda_cu_error(
[[noreturn]] MGE_WIN_DECLSPEC_FUC void _on_cuda_cu_error(
const char* expr, CUresult err, const char* file, const char* func, int line); const char* expr, CUresult err, const char* file, const char* func, int line);
#endif #endif


@@ -509,13 +509,14 @@ public:
bool* do_task_inplace = nullptr; bool* do_task_inplace = nullptr;
#endif #endif


void enable_dispatch();
MGE_WIN_DECLSPEC_FUC void enable_dispatch();


void disable_dispatch(bool* flag);
MGE_WIN_DECLSPEC_FUC void disable_dispatch(bool* flag);


void dispatch(Task&& task) const;
MGE_WIN_DECLSPEC_FUC void dispatch(Task&& task) const;


void dispatch(MultiThreadingTask&& task, size_t parallelism) const;
MGE_WIN_DECLSPEC_FUC void dispatch(
MultiThreadingTask&& task, size_t parallelism) const;


void set_affinity(AffinityCallBack&& cb) const { void set_affinity(AffinityCallBack&& cb) const {
dispatcher->set_affinity(std::move(cb)); dispatcher->set_affinity(std::move(cb));
@@ -560,7 +561,8 @@ private:
std::unique_ptr<UserDataContainer> m_user_data_container; std::unique_ptr<UserDataContainer> m_user_data_container;
mutable RecursiveSpinlock m_user_data_container_mtx; mutable RecursiveSpinlock m_user_data_container_mtx;


[[noreturn]] void on_bad_device_type(DeviceType expected) const;
[[noreturn]] MGE_WIN_DECLSPEC_FUC void on_bad_device_type(
DeviceType expected) const;


#if MGB_ENABLE_COMP_NODE_ASYNC_INIT #if MGB_ENABLE_COMP_NODE_ASYNC_INIT
//! whether async init is in future; set by init*_async methods //! whether async init is in future; set by init*_async methods
@@ -575,7 +577,7 @@ private:
} }
} }


void wait_async_init();
MGE_WIN_DECLSPEC_FUC void wait_async_init();
#else #else
void ensure_async_init_finished() const {} void ensure_async_init_finished() const {}
#endif #endif
@@ -597,10 +599,10 @@ class MegDNNHandle final : public UserDataContainer::UserData,
#endif #endif


public: public:
MegDNNHandle(const CompNodeEnv& env);
MGE_WIN_DECLSPEC_FUC MegDNNHandle(const CompNodeEnv& env);
~MegDNNHandle() noexcept; ~MegDNNHandle() noexcept;


static MegDNNHandle& get(const CompNodeEnv& env);
MGE_WIN_DECLSPEC_FUC static MegDNNHandle& get(const CompNodeEnv& env);


megdnn::Handle* operator->() const { return handle(); } megdnn::Handle* operator->() const { return handle(); }




+ 10
- 7
src/core/include/megbrain/dtype.h

@@ -97,7 +97,7 @@ public:
/*! /*!
* \brief set to given value by raw storage * \brief set to given value by raw storage
*/ */
DTypeScalar& set_raw(DType dtype, const void* storage);
MGE_WIN_DECLSPEC_FUC DTypeScalar& set_raw(DType dtype, const void* storage);


/*! /*!
* \brief set to given value, with dtype corresponding to ctype * \brief set to given value, with dtype corresponding to ctype
@@ -114,7 +114,8 @@ public:
* \brief set to given value, but use current dtype and cast value to it * \brief set to given value, but use current dtype and cast value to it
*/ */
template <typename ctype> template <typename ctype>
typename ctype_enable_if<ctype>::type set_retain_dtype(ctype val);
MGE_WIN_DECLSPEC_FUC typename ctype_enable_if<ctype>::type set_retain_dtype(
ctype val);


/*! /*!
* \brief get underlying value, which must be exactly given type * \brief get underlying value, which must be exactly given type
@@ -172,30 +173,32 @@ static_assert(
sizeof(DTypeScalar) == sizeof(DTypeScalar::max_ctype) + sizeof(DType), sizeof(DTypeScalar) == sizeof(DTypeScalar::max_ctype) + sizeof(DType),
"bad DTypeScalar size"); "bad DTypeScalar size");


DType dtype_promotion(DType t0, DType t1);
MGE_WIN_DECLSPEC_FUC DType dtype_promotion(DType t0, DType t1);


/*! /*!
* \brief copy from byte representation to compact representation for lowbit * \brief copy from byte representation to compact representation for lowbit
* types * types
*/ */
void lowbit_memcpy_byte2compact(DType dtype, void* dest, const void* src, size_t n);
MGE_WIN_DECLSPEC_FUC void lowbit_memcpy_byte2compact(
DType dtype, void* dest, const void* src, size_t n);


/*! /*!
* \brief copy from compact representation to byte representation for lowbit * \brief copy from compact representation to byte representation for lowbit
* types * types
*/ */
void lowbit_memcpy_compact2byte(DType dtype, void* dest, const void* src, size_t n);
MGE_WIN_DECLSPEC_FUC void lowbit_memcpy_compact2byte(
DType dtype, void* dest, const void* src, size_t n);


/*!
* \brief copy from byte representation to an aligned tensor for lowbit types
*/
void lowbit_memcpy_byte2aligned(
MGE_WIN_DECLSPEC_FUC void lowbit_memcpy_byte2aligned(
void* dest, const void* src, const ::megdnn::TensorLayout& ly); void* dest, const void* src, const ::megdnn::TensorLayout& ly);


/*!
* \brief copy from an aligned tensor to byte representation for lowbit types
*/
void lowbit_memcpy_aligned2byte(
MGE_WIN_DECLSPEC_FUC void lowbit_memcpy_aligned2byte(
void* dest, const void* src, const ::megdnn::TensorLayout& ly); void* dest, const void* src, const ::megdnn::TensorLayout& ly);


} // namespace mgb } // namespace mgb


+ 1
- 1
src/core/include/megbrain/exception.h

@@ -110,7 +110,7 @@ public:


private: private:
std::shared_ptr<ExtraInfo> m_extra_info; std::shared_ptr<ExtraInfo> m_extra_info;
void init();
MGE_WIN_DECLSPEC_FUC void init();
}; };


//! base class for system error: error caused by uncontrollable environment //! base class for system error: error caused by uncontrollable environment


+ 10
- 7
src/core/include/megbrain/graph/cg.h

@@ -48,7 +48,7 @@ public:
* \param[out] dest output tensor storage; its comp node has been * \param[out] dest output tensor storage; its comp node has been
* initialized to target comp node * initialized to target comp node
*/ */
virtual void alloc_static(
MGE_WIN_DECLSPEC_FUC virtual void alloc_static(
ComputingGraph* graph, DeviceTensorStorage& dest, size_t size); ComputingGraph* graph, DeviceTensorStorage& dest, size_t size);


/*! /*!
@@ -59,7 +59,8 @@ public:
* Note: if allocation fails, MemAllocError should be raised so * Note: if allocation fails, MemAllocError should be raised so
* VarDevMemDefragmenter can catch the error and do defragmentation. * VarDevMemDefragmenter can catch the error and do defragmentation.
*/ */
virtual void alloc_dynamic(VarNode* var, DeviceTensorStorage& dest, size_t size);
MGE_WIN_DECLSPEC_FUC virtual void alloc_dynamic(
VarNode* var, DeviceTensorStorage& dest, size_t size);


/*! /*!
* \brief Ensure a contiguous storage for memory defragmenter * \brief Ensure a contiguous storage for memory defragmenter
@@ -68,7 +69,7 @@ public:
* allocation requests can be placed in a contiguous storage. This function * allocation requests can be placed in a contiguous storage. This function
* would be called before calling alloc_dynamic() on the individual vars. * would be called before calling alloc_dynamic() on the individual vars.
*/ */
virtual void defrag_prealloc_contig(
MGE_WIN_DECLSPEC_FUC virtual void defrag_prealloc_contig(
ComputingGraph* graph, CompNode comp_node, size_t size); ComputingGraph* graph, CompNode comp_node, size_t size);


/*! /*!
@@ -77,7 +78,8 @@ public:
* If version changes before graph exec, static memory would be reallocated. * If version changes before graph exec, static memory would be reallocated.
* This function would be only called once in each graph execution. * This function would be only called once in each graph execution.
*/ */
virtual size_t static_alloc_version(ComputingGraph* graph) const;
MGE_WIN_DECLSPEC_FUC virtual size_t static_alloc_version(
ComputingGraph* graph) const;
}; };


/** /**
@@ -168,7 +170,7 @@ struct GraphCommonOptimizeOptions {
class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
public CompNodeDepedentObject { public CompNodeDepedentObject {
public: public:
ComputingGraph();
MGE_WIN_DECLSPEC_FUC ComputingGraph();
virtual ~ComputingGraph() = default; virtual ~ComputingGraph() = default;


/*! /*!
@@ -181,10 +183,11 @@ public:


virtual size_t next_node_id() = 0; virtual size_t next_node_id() = 0;


static std::shared_ptr<ComputingGraph> make();
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<ComputingGraph> make();


//! assert that refcnt for ptr is one and destroys the ptr
static void assert_destroy(std::shared_ptr<ComputingGraph>& ptr);
MGE_WIN_DECLSPEC_FUC static void assert_destroy(
std::shared_ptr<ComputingGraph>& ptr);


/*! /*!
* \brief callback to be invoked when some output is ready * \brief callback to be invoked when some output is ready


+ 15
- 15
src/core/include/megbrain/graph/event.h

@@ -33,7 +33,7 @@ struct OprInserted {
//! associated exception if insertion fails; nullptr if no error //! associated exception if insertion fails; nullptr if no error
MegBrainError* exc; MegBrainError* exc;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -44,7 +44,7 @@ struct OprExecStart {
OperatorNodeBase* opr; OperatorNodeBase* opr;
GraphExecutable::ExecEnv* env; GraphExecutable::ExecEnv* env;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -55,7 +55,7 @@ struct AfterWait {
CompNode comp_node; CompNode comp_node;
OperatorNodeBase* opr; OperatorNodeBase* opr;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -66,7 +66,7 @@ struct OprExecKernelStart {
OperatorNodeBase* opr; OperatorNodeBase* opr;
GraphExecutable::ExecEnv* env; GraphExecutable::ExecEnv* env;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -76,7 +76,7 @@ struct OprExecKernelEnd {
OperatorNodeBase* opr; OperatorNodeBase* opr;
GraphExecutable::ExecEnv* env; GraphExecutable::ExecEnv* env;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -86,7 +86,7 @@ struct OprExecFinished {
OperatorNodeBase* opr; OperatorNodeBase* opr;
GraphExecutable::ExecEnv* env; GraphExecutable::ExecEnv* env;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -98,7 +98,7 @@ struct BeforeKernel {
OperatorNodeBase* opr; OperatorNodeBase* opr;
CompNode comp_node; CompNode comp_node;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -110,7 +110,7 @@ struct AfterKernel {
OperatorNodeBase* opr; OperatorNodeBase* opr;
CompNode comp_node; CompNode comp_node;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -128,7 +128,7 @@ struct StaticMemAlloc {
CompNode comp_node; CompNode comp_node;
size_t alloc_size; size_t alloc_size;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -139,7 +139,7 @@ struct CompSeqOrderDetermined {
ComputingGraph* graph; ComputingGraph* graph;
AsyncExecutable* exec; AsyncExecutable* exec;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -163,7 +163,7 @@ struct CompSeqExecBeforeStart {
//! configuration) //! configuration)
size_t seq_version; size_t seq_version;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -197,7 +197,7 @@ struct CompSeqExecFinished {
ComputingGraph* graph; ComputingGraph* graph;
AsyncExecutable* exec; AsyncExecutable* exec;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -211,7 +211,7 @@ struct CompSeqExecError {
ComputingGraph* grah; ComputingGraph* grah;
AsyncExecutable* exec; AsyncExecutable* exec;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


/*! /*!
@@ -221,7 +221,7 @@ struct SubgraphAssociated {
ComputingGraph* par_graph; ComputingGraph* par_graph;
ComputingGraph* sub_graph; ComputingGraph* sub_graph;


MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
}; };


#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER #if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER
@@ -229,7 +229,7 @@ struct SubgraphAssociated {
* \brief signaled before graph memory defragmentation
*/
struct BeforeMemDefrag { struct BeforeMemDefrag {
MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
};
#endif
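The switch to the _WITH_EXPORT typeinfo declarations is needed because, unlike functions, static data members are never imported implicitly on Windows; this is exactly what the MGE_DLL_EXPORT_DATA definition in the CMake hunk above exists for. A hypothetical illustration of the underlying pattern (these helper names are made up and are not MegBrain's actual macro bodies):

    // illustration only: everything except MGE_DLL_EXPORT_DATA is invented here
    #if defined(_WIN32) && defined(MGE_DLL_EXPORT_DATA)
        #define MY_DATA_EXPORT __declspec(dllexport)
    #elif defined(_WIN32)
        #define MY_DATA_EXPORT __declspec(dllimport)
    #else
        #define MY_DATA_EXPORT
    #endif

    struct Typeinfo { const char* name; };

    struct SomeEvent {
        // without the explicit declspec, code in another binary that refers to
        // SomeEvent::typeinfo_obj would fail to link against megengine.dll;
        // the definition itself lives in one of the DLL's .cpp files
        MY_DATA_EXPORT static Typeinfo typeinfo_obj;
    };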




+ 1
- 1
src/core/include/megbrain/graph/extern_copr_api.h

@@ -20,7 +20,7 @@ namespace mgb {
/*! /*!
* \brief config extern c opr dynamic param * \brief config extern c opr dynamic param
*/ */
void config_extern_c_opr_dynamic_param(
MGE_WIN_DECLSPEC_FUC void config_extern_c_opr_dynamic_param(
std::unique_ptr<cg::AsyncExecutable>& func, std::unique_ptr<cg::AsyncExecutable>& func,
std::shared_ptr<ExternCOprParam> param); std::shared_ptr<ExternCOprParam> param);




+ 8
- 8
src/core/include/megbrain/graph/grad_impl.h

@@ -31,7 +31,7 @@ public:


//! check that m_all.size() matches opr->input().size(), and return //! check that m_all.size() matches opr->input().size(), and return
//! m_all //! m_all
VarNodeArray& all(OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC VarNodeArray& all(OperatorNodeBase* opr);
}; };


/*! /*!
@@ -69,12 +69,12 @@ using VarVirtualReceiverGrad = thin_function<VarNode*(
/*! /*!
* \brief register grad func for an operator type * \brief register grad func for an operator type
*/ */
void register_grad_func(Typeinfo* opr_type, OprGradFunc grad);
MGE_WIN_DECLSPEC_FUC void register_grad_func(Typeinfo* opr_type, OprGradFunc grad);


/*! /*!
* \brief lookup grad func for an operator type * \brief lookup grad func for an operator type
*/ */
OprGradFunc* lookup_grad_func(Typeinfo* opr_type);
MGE_WIN_DECLSPEC_FUC OprGradFunc* lookup_grad_func(Typeinfo* opr_type);


/*! /*!
* \brief add a callback to be invoked when grad of given var is computed * \brief add a callback to be invoked when grad of given var is computed
@@ -85,7 +85,7 @@ OprGradFunc* lookup_grad_func(Typeinfo* opr_type);
* Remember to call add_extra_dep_for_grad if the GradTransformer needs to * Remember to call add_extra_dep_for_grad if the GradTransformer needs to
* compute grad on other var. * compute grad on other var.
*/ */
void add_grad_transformer(VarNode* var, const GradTransformer& cb);
MGE_WIN_DECLSPEC_FUC void add_grad_transformer(VarNode* var, const GradTransformer& cb);


/*! /*!
* \brief set a callback to compute the gradient of *inputs* * \brief set a callback to compute the gradient of *inputs*
@@ -96,7 +96,7 @@ void add_grad_transformer(VarNode* var, const GradTransformer& cb);
* Note: graph transformation should be disabled until grad has been * Note: graph transformation should be disabled until grad has been
* computed if virtual receiver is needed * computed if virtual receiver is needed
*/ */
void add_var_virtual_receiver(
MGE_WIN_DECLSPEC_FUC void add_var_virtual_receiver(
const VarNodeArray& inputs, const VarNodeArray& outputs, const VarNodeArray& inputs, const VarNodeArray& outputs,
const VarVirtualReceiverGrad& grad); const VarVirtualReceiverGrad& grad);


@@ -108,7 +108,7 @@ void add_var_virtual_receiver(
* *
* \param add_volatile_out see call_opr_grad_on_given_io * \param add_volatile_out see call_opr_grad_on_given_io
*/ */
void add_var_virtual_receiver_reuse_opr_grad(
MGE_WIN_DECLSPEC_FUC void add_var_virtual_receiver_reuse_opr_grad(
const VarNodeArray& inputs, const VarNodeArray& outputs, OperatorNodeBase* opr, const VarNodeArray& inputs, const VarNodeArray& outputs, OperatorNodeBase* opr,
bool add_volatile_out); bool add_volatile_out);


@@ -119,7 +119,7 @@ void add_var_virtual_receiver_reuse_opr_grad(
* graph, so when computing gradients, \p inp would be considered to * graph, so when computing gradients, \p inp would be considered to
* contribute to target var if \p out contributes to target var. * contribute to target var if \p out contributes to target var.
*/ */
void add_extra_dep_for_grad(VarNode* inp, VarNode* out);
MGE_WIN_DECLSPEC_FUC void add_extra_dep_for_grad(VarNode* inp, VarNode* out);


/*! /*!
* \brief call registered OprGradFunc on given input and output vars * \brief call registered OprGradFunc on given input and output vars
@@ -130,7 +130,7 @@ void add_extra_dep_for_grad(VarNode* inp, VarNode* out);
* \param add_volatile_out whether to add null vars in the place of volatile * \param add_volatile_out whether to add null vars in the place of volatile
* output vars to outputs * output vars to outputs
*/ */
VarNode* call_opr_grad_on_given_io(
MGE_WIN_DECLSPEC_FUC VarNode* call_opr_grad_on_given_io(
OperatorNodeBase* opr, const VarNodeArray& inputs, const VarNodeArray& outputs, OperatorNodeBase* opr, const VarNodeArray& inputs, const VarNodeArray& outputs,
size_t idx, const VarNodeArray& out_grad, bool add_volatile_out); size_t idx, const VarNodeArray& out_grad, bool add_volatile_out);




+ 36
- 29
src/core/include/megbrain/graph/helper.h

@@ -24,7 +24,8 @@ class VarNode;
* \brief get the involved comp nodes of an operator; the operator must have * \brief get the involved comp nodes of an operator; the operator must have
* been compiled * been compiled
*/ */
CompNode::UnorderedSet get_opr_comp_node_set(OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC CompNode::UnorderedSet get_opr_comp_node_set(
OperatorNodeBase* opr);


/*! /*!
* \brief whether var shape could be statically inferred * \brief whether var shape could be statically inferred
@@ -102,22 +103,24 @@ static inline bool need_device_computing_on_var(
/*! /*!
* \brief whether all input vars of an operator has static storage * \brief whether all input vars of an operator has static storage
*/ */
bool is_all_input_static_storage(OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC bool is_all_input_static_storage(OperatorNodeBase* opr);


/*! /*!
* \brief transform a SymbolVarArray to a VarNodeArray * \brief transform a SymbolVarArray to a VarNodeArray
*/ */
VarNodeArray to_var_node_array(const SymbolVarArray& symbol_var_array);
MGE_WIN_DECLSPEC_FUC VarNodeArray
to_var_node_array(const SymbolVarArray& symbol_var_array);


/*! /*!
* \brief transform a VarNodeArray to a SymbolVarArray * \brief transform a VarNodeArray to a SymbolVarArray
*/ */
SymbolVarArray to_symbol_var_array(const VarNodeArray& var_node_array);
MGE_WIN_DECLSPEC_FUC SymbolVarArray
to_symbol_var_array(const VarNodeArray& var_node_array);


/*! /*!
* \brief return a string to describe the list of variables * \brief return a string to describe the list of variables
*/ */
std::string dump_var_info(const VarNodeArrayView& vars);
MGE_WIN_DECLSPEC_FUC std::string dump_var_info(const VarNodeArrayView& vars);


/*! /*!
* \brief compute grad of target w.r.t. wrt (i.e. d(target)/d(wrt)) * \brief compute grad of target w.r.t. wrt (i.e. d(target)/d(wrt))
@@ -127,24 +130,24 @@ std::string dump_var_info(const VarNodeArrayView& vars);
* \return the var representing grad, or nullptr if target does not depend on * \return the var representing grad, or nullptr if target does not depend on
* wrt * wrt
*/ */
SymbolVar grad(
SymbolVar target, SymbolVar wrt, bool warn_mid_wrt = true,
bool return_zero_for_nodep = true);
MGE_WIN_DECLSPEC_FUC SymbolVar
grad(SymbolVar target, SymbolVar wrt, bool warn_mid_wrt = true,
bool return_zero_for_nodep = true);


/*!
* \brief equivalent to calling grad(grad, wrt) one by one if symbolic;
* since the cache in the grad manager would be cleared each time, this method
* is more efficient if eager.
*/
SymbolVarArray grad(
SymbolVar target, SymbolVarArray wrts, bool warn_mid_wrt = true,
bool return_zero_for_nodep = true);
MGE_WIN_DECLSPEC_FUC SymbolVarArray
grad(SymbolVar target, SymbolVarArray wrts, bool warn_mid_wrt = true,
bool return_zero_for_nodep = true);


/*! /*!
* \brief get current grad target, which must be called inside * \brief get current grad target, which must be called inside
* OperatorNodeBase::grad() implementations * OperatorNodeBase::grad() implementations
*/ */
SymbolVar current_grad_target(ComputingGraph& graph);
MGE_WIN_DECLSPEC_FUC SymbolVar current_grad_target(ComputingGraph& graph);


struct SpecialOprStat { struct SpecialOprStat {
bool has_virtual_grad = false; bool has_virtual_grad = false;
@@ -158,7 +161,7 @@ struct SpecialOprStat {
* \return a list of vars corresponding to \p dest whose dependencies have been
* replaced according to \p varmap
*/ */
SymbolVarArray replace_vars(
MGE_WIN_DECLSPEC_FUC SymbolVarArray replace_vars(
const SymbolVarArray& dest, const ThinHashMap<SymbolVar, SymbolVar>& varmap); const SymbolVarArray& dest, const ThinHashMap<SymbolVar, SymbolVar>& varmap);


/*! /*!
@@ -169,7 +172,7 @@ SymbolVarArray replace_vars(
* \return a list of vars corresponding to \p dest whose dependencies have been
* replaced according to \p oprmap
*/ */
SymbolVarArray replace_oprs(
MGE_WIN_DECLSPEC_FUC SymbolVarArray replace_oprs(
const SymbolVarArray& dest, const SymbolVarArray& dest,
const ThinHashMap<OperatorNodeBase*, OperatorNodeBase*>& oprmap); const ThinHashMap<OperatorNodeBase*, OperatorNodeBase*>& oprmap);


@@ -180,10 +183,10 @@ SymbolVarArray replace_oprs(
* \return a list of vars corresponding to \p dest whose owner_graph has been
* replaced with \p new_graph
*/ */
SymbolVarArray replace_vars_comp_graph(
const SymbolVarArray& dest, ComputingGraph* new_graph);
MGE_WIN_DECLSPEC_FUC SymbolVarArray
replace_vars_comp_graph(const SymbolVarArray& dest, ComputingGraph* new_graph);


SymbolVarArray find_h2d(const SymbolVarArray& dest);
MGE_WIN_DECLSPEC_FUC SymbolVarArray find_h2d(const SymbolVarArray& dest);


/*! /*!
* \brief go through OperatorNodeBase::NodeProp::Attribute::src_opr until it * \brief go through OperatorNodeBase::NodeProp::Attribute::src_opr until it
@@ -191,7 +194,7 @@ SymbolVarArray find_h2d(const SymbolVarArray& dest);
* *
* This function also performs path compression * This function also performs path compression
*/ */
OperatorNodeBase* get_opr_root_source_opr(OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC OperatorNodeBase* get_opr_root_source_opr(OperatorNodeBase* opr);


//! describes how two mem plans intersect //! describes how two mem plans intersect
enum class MemPlanIntersectionType { enum class MemPlanIntersectionType {
@@ -199,13 +202,14 @@ enum class MemPlanIntersectionType {
IDENTICAL, //!< completely same IDENTICAL, //!< completely same
OVERLAP //!< intersects but not identical OVERLAP //!< intersects but not identical
}; };
MemPlanIntersectionType get_mem_plan_intersection_type(VarNode* a, VarNode* b);
MGE_WIN_DECLSPEC_FUC MemPlanIntersectionType
get_mem_plan_intersection_type(VarNode* a, VarNode* b);


/*! /*!
* \brief request output var to writable forward input var if no mem plan of * \brief request output var to writable forward input var if no mem plan of
* other input vars intersects with this input var * other input vars intersects with this input var
*/ */
void request_fwd_in2out_writable_if_no_mem_ovelap(
MGE_WIN_DECLSPEC_FUC void request_fwd_in2out_writable_if_no_mem_ovelap(
OperatorNodeBase* opr, size_t inp, size_t out); OperatorNodeBase* opr, size_t inp, size_t out);


/*! /*!
@@ -217,7 +221,7 @@ void request_fwd_in2out_writable_if_no_mem_ovelap(
* *
* Note: implemented in cg_impl.cpp, since it is used during graph init * Note: implemented in cg_impl.cpp, since it is used during graph init
*/ */
void update_output_var_shapes(OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC void update_output_var_shapes(OperatorNodeBase* opr);


/*! /*!
* \brief add an output to be used as the workspace for an operator * \brief add an output to be used as the workspace for an operator
@@ -227,17 +231,19 @@ void update_output_var_shapes(OperatorNodeBase* opr);
* This helper is usually called from an opr constructor and used for adding the * This helper is usually called from an opr constructor and used for adding the
* last output. * last output.
*/ */
void add_workspace_output(OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC void add_workspace_output(OperatorNodeBase* opr);


/*! /*!
* \brief copy a raw tensor shape into a host tensor * \brief copy a raw tensor shape into a host tensor
*/ */
void copy_shape_to_tensor_value(DeviceTensorND& dest, const TensorShape& shp);
MGE_WIN_DECLSPEC_FUC void copy_shape_to_tensor_value(
DeviceTensorND& dest, const TensorShape& shp);


/*! /*!
* \brief copy value of a host tensor into a raw tensor shape * \brief copy value of a host tensor into a raw tensor shape
*/ */
void copy_tensor_value_to_shape(TensorShape& dest, const DeviceTensorND& val);
MGE_WIN_DECLSPEC_FUC void copy_tensor_value_to_shape(
TensorShape& dest, const DeviceTensorND& val);


/*! /*!
* \brief get a symbolvar whose value is tensor shape, used for other * \brief get a symbolvar whose value is tensor shape, used for other
@@ -246,7 +252,7 @@ void copy_tensor_value_to_shape(TensorShape& dest, const DeviceTensorND& val);
* \param opr_name operator that invokes this function; used in error * \param opr_name operator that invokes this function; used in error
* function if *config* is invalid * function if *config* is invalid
*/ */
SymbolVar var_from_tensor_shape(
MGE_WIN_DECLSPEC_FUC SymbolVar var_from_tensor_shape(
ComputingGraph& graph, const OperatorNodeConfig& config, const char* opr_name, ComputingGraph& graph, const OperatorNodeConfig& config, const char* opr_name,
const TensorShape& shape); const TensorShape& shape);


@@ -275,7 +281,7 @@ public:
: m_cb{std::move(cb)}, m_extra_dep(std::move(extra_dep)) {} : m_cb{std::move(cb)}, m_extra_dep(std::move(extra_dep)) {}


//! add an operator whose deps should be discovered //! add an operator whose deps should be discovered
void add(OperatorNodeBase* dest);
MGE_WIN_DECLSPEC_FUC void add(OperatorNodeBase* dest);


void add(SymbolVar var) { add(var.node()->owner_opr()); } void add(SymbolVar var) { add(var.node()->owner_opr()); }


@@ -334,7 +340,7 @@ public:
* *
* This function should be called only once on a graph * This function should be called only once on a graph
*/ */
static void register_to(
MGE_WIN_DECLSPEC_FUC static void register_to(
ComputingGraph* dest, const ComputingGraph* src, const TransFunc& trans); ComputingGraph* dest, const ComputingGraph* src, const TransFunc& trans);


/*! /*!
@@ -342,12 +348,13 @@ public:
* \return previously registered transformer on given graph or nullptr * \return previously registered transformer on given graph or nullptr
* if none registered * if none registered
*/ */
static const InterGraphVarTransformer* get(const ComputingGraph& graph);
MGE_WIN_DECLSPEC_FUC static const InterGraphVarTransformer* get(
const ComputingGraph& graph);


/*! /*!
* \brief transform a var into this graph * \brief transform a var into this graph
*/ */
VarNode* trans(VarNode* src) const;
MGE_WIN_DECLSPEC_FUC VarNode* trans(VarNode* src) const;


private: private:
ComputingGraph* m_graph_dest; ComputingGraph* m_graph_dest;


+ 48
- 38
src/core/include/megbrain/graph/operator_node.h

@@ -31,13 +31,13 @@ class ExecutionMask;
* \brief configuration for operator nodes * \brief configuration for operator nodes
*/ */
class OperatorNodeConfig final : public Hashable { class OperatorNodeConfig final : public Hashable {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


public: public:
using CompNodeArray = SmallVector<CompNode, 1>; using CompNodeArray = SmallVector<CompNode, 1>;


OperatorNodeConfig() = default; OperatorNodeConfig() = default;
~OperatorNodeConfig();
MGE_WIN_DECLSPEC_FUC ~OperatorNodeConfig();


OperatorNodeConfig(std::string name) : m_name{std::move(name)} {} OperatorNodeConfig(std::string name) : m_name{std::move(name)} {}


@@ -101,18 +101,18 @@ public:
/*! /*!
* \brief set preferred single comp node * \brief set preferred single comp node
*/ */
OperatorNodeConfig& comp_node(const CompNode& node);
MGE_WIN_DECLSPEC_FUC OperatorNodeConfig& comp_node(const CompNode& node);


/*! /*!
* \brief directly set all the CompNodes * \brief directly set all the CompNodes
*/ */
OperatorNodeConfig& comp_node_arr(const CompNodeArray& arr);
MGE_WIN_DECLSPEC_FUC OperatorNodeConfig& comp_node_arr(const CompNodeArray& arr);


/*! /*!
* \brief get single comp node if the user has set it, or an invalid * \brief get single comp node if the user has set it, or an invalid
* comp node if the config is empty * comp node if the config is empty
*/ */
CompNode get_single_comp_node() const;
MGE_WIN_DECLSPEC_FUC CompNode get_single_comp_node() const;


/*! /*!
* \brief follow the computing node of dest * \brief follow the computing node of dest
@@ -121,7 +121,7 @@ public:
return comp_node(dest.node()->comp_node()); return comp_node(dest.node()->comp_node());
} }


OperatorNodeConfig& output_dtype(DType dtype);
MGE_WIN_DECLSPEC_FUC OperatorNodeConfig& output_dtype(DType dtype);


DType output_dtype() const { return m_output_dtype; } DType output_dtype() const { return m_output_dtype; }


@@ -132,9 +132,9 @@ public:


const CompNodeArray& comp_node() const { return m_comp_node; } const CompNodeArray& comp_node() const { return m_comp_node; }


size_t hash() const override;
MGE_WIN_DECLSPEC_FUC size_t hash() const override;


bool is_same_st(const Hashable& rhs) const override;
MGE_WIN_DECLSPEC_FUC bool is_same_st(const Hashable& rhs) const override;


private: private:
static constexpr size_t sm_initial_instance_id = 1333331; static constexpr size_t sm_initial_instance_id = 1333331;
@@ -163,7 +163,7 @@ public:
* *
* The default implementation does nothing * The default implementation does nothing
*/ */
virtual void record_execute_deps(ExecDependencyArray& record);
MGE_WIN_DECLSPEC_FUC virtual void record_execute_deps(ExecDependencyArray& record);


protected: protected:
~GraphExecutable() = default; ~GraphExecutable() = default;
@@ -408,7 +408,7 @@ public:
* \brief reset dep type; the vars could contain duplicated var nodes, * \brief reset dep type; the vars could contain duplicated var nodes,
* in which case the corresponding dep type would be ORed together * in which case the corresponding dep type would be ORed together
*/ */
void reset_dep_type(
MGE_WIN_DECLSPEC_FUC void reset_dep_type(
const VarNodeArray& vars, const SmallVector<DepType>& dep_types); const VarNodeArray& vars, const SmallVector<DepType>& dep_types);


/*! /*!
@@ -488,11 +488,11 @@ public:
const VarNodeArrayView& input_var_naming; const VarNodeArrayView& input_var_naming;
}; };


virtual ~OperatorNodeBase() noexcept;
MGE_WIN_DECLSPEC_FUC virtual ~OperatorNodeBase() noexcept;


#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
/* ===================== json io ===================== */ /* ===================== json io ===================== */
std::shared_ptr<json::Value> to_json() const override;
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> to_json() const override;


//! extra value to be added to json //! extra value to be added to json
std::shared_ptr<json::Object> to_json_extra_json = json::Object::make(); std::shared_ptr<json::Object> to_json_extra_json = json::Object::make();
@@ -511,7 +511,7 @@ public:
const VarNodeArray& output() const { return m_output; } const VarNodeArray& output() const { return m_output; }


// non-volatile outputs // non-volatile outputs
const VarNodeArray usable_output() const;
MGE_WIN_DECLSPEC_FUC const VarNodeArray usable_output() const;


VarNode* input(size_t idx) const { return m_input.at(idx); } VarNode* input(size_t idx) const { return m_input.at(idx); }


@@ -519,7 +519,7 @@ public:


//! hash that combines all inputs, m_config.comp_node() and all //! hash that combines all inputs, m_config.comp_node() and all
//! add_equivalence_component calls //! add_equivalence_component calls
size_t hash() const override final;
MGE_WIN_DECLSPEC_FUC size_t hash() const override final;


/*! /*!
* \brief get node prop, which is available and constant after node * \brief get node prop, which is available and constant after node
@@ -527,7 +527,7 @@ public:
* *
* Note that this function calls do_make_node_prop() on first call * Note that this function calls do_make_node_prop() on first call
*/ */
const NodeProp& node_prop() const;
MGE_WIN_DECLSPEC_FUC const NodeProp& node_prop() const;


/*! /*!
* \brief called by ComputingGraph to mark that this node has been * \brief called by ComputingGraph to mark that this node has been
@@ -549,7 +549,7 @@ public:
* 3. call do_execute * 3. call do_execute
* 4. set_ready on output * 4. set_ready on output
*/ */
void execute(ExecEnv& env) override final;
MGE_WIN_DECLSPEC_FUC void execute(ExecEnv& env) override final;


/*! /*!
* \brief specifies waiting strategy on one comp node for input vars * \brief specifies waiting strategy on one comp node for input vars
@@ -617,7 +617,7 @@ public:
* \brief get callbacks to be invoked on events related to this * \brief get callbacks to be invoked on events related to this
* operator; default implementation returns empty event * operator; default implementation returns empty event
*/ */
virtual OprEventCallback get_opr_event_callback();
MGE_WIN_DECLSPEC_FUC virtual OprEventCallback get_opr_event_callback();


/*! /*!
* \brief called when stream of comp node of output vars is changed for * \brief called when stream of comp node of output vars is changed for
@@ -635,7 +635,7 @@ public:
* *
* This function is called once during operator insertion. * This function is called once during operator insertion.
*/ */
virtual void init_output_dtype();
MGE_WIN_DECLSPEC_FUC virtual void init_output_dtype();


/*! /*!
* \brief initialize output format by calling VarNode::format * \brief initialize output format by calling VarNode::format
@@ -645,7 +645,7 @@ public:
* *
* This function is called once during operator insertion * This function is called once during operator insertion
*/ */
virtual void init_output_format();
MGE_WIN_DECLSPEC_FUC virtual void init_output_format();


/*!
* \brief initialize output comp_node by calling VarNode::comp_node
@@ -687,7 +687,7 @@ public:
* \param dynamic if true, initialize mem plans for vars that could not * \param dynamic if true, initialize mem plans for vars that could not
* be statically inferred; otherwise for statically inferable vars * be statically inferred; otherwise for statically inferable vars
*/ */
virtual void init_output_mem_plan(bool dynamic);
MGE_WIN_DECLSPEC_FUC virtual void init_output_mem_plan(bool dynamic);


/* /*
* ============================================================= * =============================================================
@@ -703,7 +703,7 @@ public:
}; };


//! add input var to this operator //! add input var to this operator
void add_input(
MGE_WIN_DECLSPEC_FUC void add_input(
std::initializer_list<VarNode*> list, std::initializer_list<VarNode*> list,
AddInputSortType sort_type = AddInputSortType::NONE); AddInputSortType sort_type = AddInputSortType::NONE);


@@ -711,7 +711,7 @@ public:
* \brief allocate a new output VarNode; the name would be appended to * \brief allocate a new output VarNode; the name would be appended to
* this->name to form the final name * this->name to form the final name
*/ */
VarNode* add_output(const Maybe<std::string>& name);
MGE_WIN_DECLSPEC_FUC VarNode* add_output(const Maybe<std::string>& name);


/*! /*!
* \brief add extra component for equivalence check * \brief add extra component for equivalence check
@@ -734,7 +734,7 @@ public:
* \brief allocate a new node prop and initialize dep entry as all * \brief allocate a new node prop and initialize dep entry as all
* inputs * inputs
*/ */
virtual NodeProp* do_make_node_prop() const;
MGE_WIN_DECLSPEC_FUC virtual NodeProp* do_make_node_prop() const;


/*! /*!
* \brief Update operator priority. * \brief Update operator priority.
@@ -744,13 +744,13 @@ public:
* priority. * priority.
* \return whether the priority would be changed. * \return whether the priority would be changed.
*/ */
virtual bool update_priority() const;
MGE_WIN_DECLSPEC_FUC virtual bool update_priority() const;


protected: protected:
/*! /*!
* \param input_var_naming used for generating default node name * \param input_var_naming used for generating default node name
*/ */
OperatorNodeBase(
MGE_WIN_DECLSPEC_FUC OperatorNodeBase(
ComputingGraph* owner, const OperatorNodeConfig& config, ComputingGraph* owner, const OperatorNodeConfig& config,
const std::string& default_name, const VarNodeArrayView& input_var_naming); const std::string& default_name, const VarNodeArrayView& input_var_naming);


@@ -781,9 +781,10 @@ private:
mutable Maybe<NodeProp> m_node_prop; mutable Maybe<NodeProp> m_node_prop;
Maybe<InputWaitingSpec> m_input_waiting_spec; Maybe<InputWaitingSpec> m_input_waiting_spec;


void do_add_equivalence_component(HashableContainer&& hashable);
MGE_WIN_DECLSPEC_FUC void do_add_equivalence_component(
HashableContainer&& hashable);


bool is_same_st(const Hashable& rhs) const override final;
MGE_WIN_DECLSPEC_FUC bool is_same_st(const Hashable& rhs) const override final;
}; };


/*! /*!
@@ -856,7 +857,7 @@ protected:
* mixin_on_output_comp_node_stream_changed(), which is called from * mixin_on_output_comp_node_stream_changed(), which is called from
* opr.on_output_comp_node_stream_changed() invoked by this function. * opr.on_output_comp_node_stream_changed() invoked by this function.
*/ */
static void mixin_init_output_comp_node(OperatorNodeBase& opr);
MGE_WIN_DECLSPEC_FUC static void mixin_init_output_comp_node(OperatorNodeBase& opr);


/*! /*!
* \brief only infer output comp node, without modifying anything * \brief only infer output comp node, without modifying anything
@@ -865,7 +866,7 @@ protected:
* least one input exists and they are all placed on the same comp node. * least one input exists and they are all placed on the same comp node.
* It also checks the comp node set in config. * It also checks the comp node set in config.
*/ */
static CompNode mixin_infer_output_comp_node(
MGE_WIN_DECLSPEC_FUC static CompNode mixin_infer_output_comp_node(
const OperatorNodeBase& opr, bool cross_mem); const OperatorNodeBase& opr, bool cross_mem);


CompNode mixin_comp_node() const { return m_comp_node; } CompNode mixin_comp_node() const { return m_comp_node; }
@@ -874,22 +875,25 @@ protected:
* \brief initialize NodeProp with SINGLE_COMP_NODE, and setup * \brief initialize NodeProp with SINGLE_COMP_NODE, and setup
* dependency on input * dependency on input
*/ */
NodeProp* mixin_do_make_node_prop(const OperatorNodeBase& opr) const;
MGE_WIN_DECLSPEC_FUC NodeProp* mixin_do_make_node_prop(
const OperatorNodeBase& opr) const;


void mixin_do_execute(OperatorNodeBase& opr, OperatorNodeBase::ExecEnv& env);
MGE_WIN_DECLSPEC_FUC void mixin_do_execute(
OperatorNodeBase& opr, OperatorNodeBase::ExecEnv& env);


void mixin_on_output_comp_node_stream_changed(OperatorNodeBase& opr);
MGE_WIN_DECLSPEC_FUC void mixin_on_output_comp_node_stream_changed(
OperatorNodeBase& opr);


/*! /*!
* \brief set comp node during initializing * \brief set comp node during initializing
*/ */
void mixin_comp_node(OperatorNodeBase& opr, CompNode node);
MGE_WIN_DECLSPEC_FUC void mixin_comp_node(OperatorNodeBase& opr, CompNode node);


/*! /*!
* \brief override by subclass to perform raw computing; this function * \brief override by subclass to perform raw computing; this function
* is already dispatched on corresponding stream in ExecEnv * is already dispatched on corresponding stream in ExecEnv
*/ */
virtual void scn_do_execute() = 0;
MGE_WIN_DECLSPEC_FUC virtual void scn_do_execute() = 0;


~SingleCNOperatorNode() = default; ~SingleCNOperatorNode() = default;
}; };
@@ -903,7 +907,8 @@ class OutshapePureByInshapeOpr : public OperatorNodeMixinBase {
size_t m_inp_run_id = -1; size_t m_inp_run_id = -1;
TensorShapeArray m_out_shp; TensorShapeArray m_out_shp;


bool infer_desc(size_t out_idx, TensorShape& dest, const StaticInferInpVal& inp);
MGE_WIN_DECLSPEC_FUC bool infer_desc(
size_t out_idx, TensorShape& dest, const StaticInferInpVal& inp);


protected: protected:
/*! /*!
@@ -912,9 +917,11 @@ protected:
* of output vars that should be managed by this helper (they would be * of output vars that should be managed by this helper (they would be
* the first vars of all output vars). * the first vars of all output vars).
*/ */
void mixin_set_nr_managed_outputs(OperatorNodeBase& opr, size_t nr);
MGE_WIN_DECLSPEC_FUC void mixin_set_nr_managed_outputs(
OperatorNodeBase& opr, size_t nr);


void mixin_init_output_static_infer_desc(OperatorNodeBase& opr);
MGE_WIN_DECLSPEC_FUC void mixin_init_output_static_infer_desc(
OperatorNodeBase& opr);


/*! /*!
* \brief get output shapes from input shapes * \brief get output shapes from input shapes
@@ -926,7 +933,7 @@ protected:
virtual void get_output_var_shape( virtual void get_output_var_shape(
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const = 0; const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const = 0;


~OutshapePureByInshapeOpr();
MGE_WIN_DECLSPEC_FUC ~OutshapePureByInshapeOpr();
}; };


/*! /*!
@@ -1010,6 +1017,9 @@ using OprNodeArray = SmallVector<OperatorNodeBase*>;
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \
MGB_DYN_TYPE_OBJ_FINAL_DECL; MGB_DYN_TYPE_OBJ_FINAL_DECL;


#define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;
} // namespace cg } // namespace cg
} // namespace mgb } // namespace mgb
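Every declaration touched in these headers follows the same pattern: functions that must cross the megengine.dll boundary gain MGE_WIN_DECLSPEC_FUC, static data members gain MGE_WIN_DECLSPEC_DATA, and headers that did not already see those macros now include megbrain_build_config.h, where they are defined. The definition itself is outside this diff; a minimal sketch of how such a switch is usually wired (MGE_DLL_EXPORT is an assumed placeholder flag name, not verified against the build scripts):

// Sketch only (assumption): the real definitions live in megbrain_build_config.h.
// MGE_DLL_EXPORT stands in for "currently compiling megengine.dll itself".
#ifdef _WIN32
#ifdef MGE_DLL_EXPORT
#define MGE_WIN_DECLSPEC_FUC __declspec(dllexport)
#define MGE_WIN_DECLSPEC_DATA __declspec(dllexport)
#else
#define MGE_WIN_DECLSPEC_FUC __declspec(dllimport)
#define MGE_WIN_DECLSPEC_DATA __declspec(dllimport)
#endif
#else
#define MGE_WIN_DECLSPEC_FUC
#define MGE_WIN_DECLSPEC_DATA
#endif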




+ 3
- 3
src/core/include/megbrain/graph/static_infer.h

@@ -117,13 +117,13 @@ struct ShapeInferDesc {
* \brief make a ShapeInferDesc that copies shape of another var into * \brief make a ShapeInferDesc that copies shape of another var into
* dest var * dest var
*/ */
static ShapeInferDesc make_identity(VarNode* src);
MGE_WIN_DECLSPEC_FUC static ShapeInferDesc make_identity(VarNode* src);


/*! /*!
* \brief make a constant ShapeInferDesc that always produces given * \brief make a constant ShapeInferDesc that always produces given
* value * value
*/ */
static ShapeInferDesc make_const(const TensorShape& shp);
MGE_WIN_DECLSPEC_FUC static ShapeInferDesc make_const(const TensorShape& shp);
}; };


/*! /*!
@@ -154,7 +154,7 @@ struct ValueInferDesc {
* \brief make a ValueInferDesc that copies shape of another var into * \brief make a ValueInferDesc that copies shape of another var into
* dest var * dest var
*/ */
static ValueInferDesc make_identity(VarNode* src);
MGE_WIN_DECLSPEC_FUC static ValueInferDesc make_identity(VarNode* src);
}; };


struct InferType { struct InferType {


+ 11
- 10
src/core/include/megbrain/graph/symbol_var.h

@@ -53,15 +53,15 @@ public:
* \brief set a new name; note that the underlying VarNode would be * \brief set a new name; note that the underlying VarNode would be
* modified, not this SymbolVar itself * modified, not this SymbolVar itself
*/ */
SymbolVar rename(const std::string& name) const;
MGE_WIN_DECLSPEC_FUC SymbolVar rename(const std::string& name) const;


SymbolVar reshape(const TensorShape& tshape) const;
SymbolVar reshape(SymbolVar tshape) const;
SymbolVar broadcast(const TensorShape& tshape) const;
SymbolVar broadcast(SymbolVar tshape) const;
SymbolVar symshape() const;
SymbolVar flatten() const;
SymbolVar add_axis(size_t idx) const;
MGE_WIN_DECLSPEC_FUC SymbolVar reshape(const TensorShape& tshape) const;
MGE_WIN_DECLSPEC_FUC SymbolVar reshape(SymbolVar tshape) const;
MGE_WIN_DECLSPEC_FUC SymbolVar broadcast(const TensorShape& tshape) const;
MGE_WIN_DECLSPEC_FUC SymbolVar broadcast(SymbolVar tshape) const;
MGE_WIN_DECLSPEC_FUC SymbolVar symshape() const;
MGE_WIN_DECLSPEC_FUC SymbolVar flatten() const;
MGE_WIN_DECLSPEC_FUC SymbolVar add_axis(size_t idx) const;


const TensorShape& shape() const { return m_node->shape(); } const TensorShape& shape() const { return m_node->shape(); }


@@ -105,7 +105,8 @@ public:
* \brief make a const scalar value on given computing graph and * \brief make a const scalar value on given computing graph and
* computing node * computing node
*/ */
static SymbolVar make_scalar(DTypeScalar value, ComputingGraph& cg, CompNode cn);
MGE_WIN_DECLSPEC_FUC static SymbolVar make_scalar(
DTypeScalar value, ComputingGraph& cg, CompNode cn);


/*! /*!
* \brief make a const scalar value using computing graph and comp node * \brief make a const scalar value using computing graph and comp node
@@ -131,7 +132,7 @@ public:
* This essentially synchronizes the dispatch queue and then call * This essentially synchronizes the dispatch queue and then call
* dev_tensor() * dev_tensor()
*/ */
const DeviceTensorND& eager_eval_get_value() const;
MGE_WIN_DECLSPEC_FUC const DeviceTensorND& eager_eval_get_value() const;


bool allow_shape_change() const { return m_node->allow_shape_change(); } bool allow_shape_change() const { return m_node->allow_shape_change(); }
}; };


+ 37
- 29
src/core/include/megbrain/graph/var_node.h

@@ -145,7 +145,7 @@ public:
MemAllocPlan& layout(const TensorLayout& dest, bool allow_shape_change = false); MemAllocPlan& layout(const TensorLayout& dest, bool allow_shape_change = false);


#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
std::shared_ptr<json::Value> to_json() const override;
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> to_json() const override;
#endif #endif


/*! /*!
@@ -153,13 +153,13 @@ public:
* *
* Release tensor storage if refcnt drops to zero * Release tensor storage if refcnt drops to zero
*/ */
MemAllocPlan& release_chunk();
MGE_WIN_DECLSPEC_FUC MemAllocPlan& release_chunk();


/*! /*!
* \brief reset chunk to a privately owned chunk, and setup offset and * \brief reset chunk to a privately owned chunk, and setup offset and
* layout from owner var, and clear tensor storage * layout from owner var, and clear tensor storage
*/ */
MemAllocPlan& reset_from_owner_var();
MGE_WIN_DECLSPEC_FUC MemAllocPlan& reset_from_owner_var();


/*! /*!
* \brief reset to a special marker that indicates this var is not * \brief reset to a special marker that indicates this var is not
@@ -187,10 +187,11 @@ public:
} }


//! assign layout, offset and chunk from another mem alloc plan //! assign layout, offset and chunk from another mem alloc plan
MemAllocPlan& assign(const MemAllocPlan& src);
MGE_WIN_DECLSPEC_FUC MemAllocPlan& assign(const MemAllocPlan& src);


//! assign for readonly forward //! assign for readonly forward
MemAllocPlan& assign_for_forward(const MemAllocPlan& src, const SubTensorSpec& sub);
MGE_WIN_DECLSPEC_FUC MemAllocPlan& assign_for_forward(
const MemAllocPlan& src, const SubTensorSpec& sub);


/*! /*!
* \brief next readonly-forward reader of this MemAllocPlan * \brief next readonly-forward reader of this MemAllocPlan
@@ -212,7 +213,7 @@ private:


public: public:
MemAllocPlan* next() const { return m_next; } MemAllocPlan* next() const { return m_next; }
void reset();
MGE_WIN_DECLSPEC_FUC void reset();
inline void insert_after(const MemAllocPlan& prev, MemAllocPlan* self); inline void insert_after(const MemAllocPlan& prev, MemAllocPlan* self);
inline void remove_self(); inline void remove_self();
}; };
@@ -261,7 +262,8 @@ public:
* with given layout may be forwarded to opr directly, otherwise it * with given layout may be forwarded to opr directly, otherwise it
* will be implicitly rearranged to a contiguous one. * will be implicitly rearranged to a contiguous one.
*/ */
VarNode& add_layout_constraint(LayoutConstraintCallback callback);
MGE_WIN_DECLSPEC_FUC VarNode& add_layout_constraint(
LayoutConstraintCallback callback);


/*! /*!
* \brief requires the layout to be contiguous * \brief requires the layout to be contiguous
@@ -272,7 +274,7 @@ public:
* existing callbacks would be cleared and new callbacks would be * existing callbacks would be cleared and new callbacks would be
* ignored after add_layout_constraint_contiguous() is invoked. * ignored after add_layout_constraint_contiguous() is invoked.
*/ */
VarNode& add_layout_constraint_contiguous();
MGE_WIN_DECLSPEC_FUC VarNode& add_layout_constraint_contiguous();


/*! /*!
* \brief requires the layout to be monotone while allowing broadcast * \brief requires the layout to be monotone while allowing broadcast
@@ -281,7 +283,7 @@ public:
* implemented by marking a flag; however user-defined callbacks are * implemented by marking a flag; however user-defined callbacks are
* still invoked since they might impose stronger constraints. * still invoked since they might impose stronger constraints.
*/ */
VarNode& add_layout_constraint_monotone();
MGE_WIN_DECLSPEC_FUC VarNode& add_layout_constraint_monotone();


/*! /*!
* \brief request that memory should be readonly forwarded from other * \brief request that memory should be readonly forwarded from other
@@ -292,7 +294,7 @@ public:
* *
* \return whether this request could be satisfied * \return whether this request could be satisfied
*/ */
MGB_WARN_UNUSED_RESULT bool set_fwd_in2out_readonly(
MGB_WARN_UNUSED_RESULT MGE_WIN_DECLSPEC_FUC bool set_fwd_in2out_readonly(
VarNode* input, const SubTensorSpec& sub); VarNode* input, const SubTensorSpec& sub);


/*! /*!
@@ -302,7 +304,7 @@ public:
* Note that this function must be called from * Note that this function must be called from
* OperatorNodeBase::mem_plan_fwd_in2out_writable. * OperatorNodeBase::mem_plan_fwd_in2out_writable.
*/ */
VarNode& set_fwd_in2out_writable(VarNode* input);
MGE_WIN_DECLSPEC_FUC VarNode& set_fwd_in2out_writable(VarNode* input);


/*! /*!
* \brief require this var to share memory from another var; only used * \brief require this var to share memory from another var; only used
@@ -311,14 +313,14 @@ public:
* Note that this function must be called during operator node * Note that this function must be called during operator node
* initialization * initialization
*/ */
VarNode& set_fwd_in2out_writable_force(VarNode* input);
MGE_WIN_DECLSPEC_FUC VarNode& set_fwd_in2out_writable_force(VarNode* input);


/* ===================== getter and setters ===================== */ /* ===================== getter and setters ===================== */


OperatorNodeBase* owner_opr() const { return m_owner; } OperatorNodeBase* owner_opr() const { return m_owner; }


//! get name; if name is not valid, get name of owner opr //! get name; if name is not valid, get name of owner opr
const std::string& name() const;
MGE_WIN_DECLSPEC_FUC const std::string& name() const;


//! get name as C-string //! get name as C-string
const char* cname() const { return name().c_str(); } const char* cname() const { return name().c_str(); }
@@ -327,7 +329,7 @@ public:
bool has_name_set() const { return m_has_name_set; } bool has_name_set() const { return m_has_name_set; }


//! set name explicitly //! set name explicitly
VarNode& name(std::string name);
MGE_WIN_DECLSPEC_FUC VarNode& name(std::string name);


//! get data type of data in this var //! get data type of data in this var
DType dtype() const { return m_dev_tensor.dtype(); } DType dtype() const { return m_dev_tensor.dtype(); }
@@ -336,10 +338,10 @@ public:
TensorFormat format() const { return m_dev_tensor.format(); } TensorFormat format() const { return m_dev_tensor.format(); }


//! set dtype; this function can only be called once //! set dtype; this function can only be called once
VarNode& dtype(DType dtype);
MGE_WIN_DECLSPEC_FUC VarNode& dtype(DType dtype);


//! set format; this function can only be called once //! set format; this function can only be called once
VarNode& format(TensorFormat format);
MGE_WIN_DECLSPEC_FUC VarNode& format(TensorFormat format);


MemAllocPlan& mem_plan() { return m_mem_plan; } MemAllocPlan& mem_plan() { return m_mem_plan; }


@@ -351,7 +353,7 @@ public:
} }


//! get the underlying device tensor to fill data //! get the underlying device tensor to fill data
const DeviceTensorND& dev_tensor() const;
MGE_WIN_DECLSPEC_FUC const DeviceTensorND& dev_tensor() const;


/*! /*!
* \brief get the underlying device tensor that can be modified(like * \brief get the underlying device tensor that can be modified(like
@@ -360,7 +362,7 @@ public:
* This should only be called from the owner opr of this var, and this * This should only be called from the owner opr of this var, and this
* var must have flag NO_SYS_MEM_ALLOC. * var must have flag NO_SYS_MEM_ALLOC.
*/ */
DeviceTensorND& mutable_dev_tensor();
MGE_WIN_DECLSPEC_FUC DeviceTensorND& mutable_dev_tensor();


/*! /*!
* \brief previous dev ptr before deallocating dev_tensor; used for * \brief previous dev ptr before deallocating dev_tensor; used for
@@ -377,7 +379,7 @@ public:
* \brief set comp node; only the memory node could be changed if called * \brief set comp node; only the memory node could be changed if called
* multiple times * multiple times
*/ */
VarNode& comp_node(const CompNode& cn);
MGE_WIN_DECLSPEC_FUC VarNode& comp_node(const CompNode& cn);


const TensorShape& shape() const { return m_shape; } const TensorShape& shape() const { return m_shape; }


@@ -389,7 +391,7 @@ public:
* \brief reset VarNode shape * \brief reset VarNode shape
* \return whether shape differs from old shape * \return whether shape differs from old shape
*/ */
VarNode& shape(const TensorShape& shape);
MGE_WIN_DECLSPEC_FUC VarNode& shape(const TensorShape& shape);


bool allow_shape_change() const { return m_allow_shape_change; } bool allow_shape_change() const { return m_allow_shape_change; }


@@ -399,7 +401,7 @@ public:
} }


#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
std::shared_ptr<json::Value> to_json() const override;
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> to_json() const override;
#endif #endif


/*! /*!
@@ -413,7 +415,7 @@ public:


enum class Flag : uint32_t; enum class Flag : uint32_t;


VarNode& add_flag(Flag flag);
MGE_WIN_DECLSPEC_FUC VarNode& add_flag(Flag flag);


inline bool contain_flag(Flag flag) const; inline bool contain_flag(Flag flag) const;


@@ -429,7 +431,8 @@ public:
* *
* \warning Alloc size_req memory if size_req != 0. * \warning Alloc size_req memory if size_req != 0.
*/ */
VarNode& shape_alloc(const TensorShape& shape, size_t size_req = 0);
MGE_WIN_DECLSPEC_FUC VarNode& shape_alloc(
const TensorShape& shape, size_t size_req = 0);


/*! /*!
* \brief directly reset device tensor from another var * \brief directly reset device tensor from another var
@@ -459,7 +462,8 @@ public:
* \param value the tensor to be used; it must be contiguous or empty * \param value the tensor to be used; it must be contiguous or empty
* and be placed on the same comp node of this var. * and be placed on the same comp node of this var.
*/ */
VarNode& reset_dev_tensor_from_tensor(const DeviceTensorND& value);
MGE_WIN_DECLSPEC_FUC VarNode& reset_dev_tensor_from_tensor(
const DeviceTensorND& value);


/*! /*!
* \brief add a var to add RT_FORCE_DYNAMIC_MEM_ALLOC flag if such flag * \brief add a var to add RT_FORCE_DYNAMIC_MEM_ALLOC flag if such flag
@@ -472,7 +476,8 @@ public:
* This method should be called from * This method should be called from
* OperatorNodeBase::init_rt_force_dynamic_mem_alloc_imply_chain impls. * OperatorNodeBase::init_rt_force_dynamic_mem_alloc_imply_chain impls.
*/ */
VarNode& add_rt_force_dynamic_mem_alloc_imply_chain(VarNode* dest);
MGE_WIN_DECLSPEC_FUC VarNode& add_rt_force_dynamic_mem_alloc_imply_chain(
VarNode* dest);


/* ===================== graph compiler special ===================== */ /* ===================== graph compiler special ===================== */


@@ -486,7 +491,8 @@ public:
* \param fixed_alloc if not null, it should be a tensor providing * \param fixed_alloc if not null, it should be a tensor providing
* memory allocation for this var. * memory allocation for this var.
*/ */
MemAllocPlan& init_mem_plan(const DeviceTensorND* fixed_alloc = nullptr);
MGE_WIN_DECLSPEC_FUC MemAllocPlan& init_mem_plan(
const DeviceTensorND* fixed_alloc = nullptr);


/*! /*!
* \brief get the shape and value infer trait * \brief get the shape and value infer trait
@@ -541,12 +547,14 @@ private:


std::vector<VarNode*> m_rt_force_dynamic_mem_alloc_imply_chain; std::vector<VarNode*> m_rt_force_dynamic_mem_alloc_imply_chain;


void modify_flag(Flag delta, Flag new_flag);
MGE_WIN_DECLSPEC_FUC void modify_flag(Flag delta, Flag new_flag);


void assign_dev_tensor_from_tensor(const DeviceTensorND& value);
MGE_WIN_DECLSPEC_FUC void assign_dev_tensor_from_tensor(
const DeviceTensorND& value);


#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
std::shared_ptr<json::Value> dump_static_infer_info_to_json() const;
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> dump_static_infer_info_to_json()
const;
#endif #endif


friend class static_infer::StaticInferManagerImpl; friend class static_infer::StaticInferManagerImpl;


+ 7
- 7
src/core/include/megbrain/system.h

@@ -25,27 +25,27 @@ namespace mgb {
namespace sys { namespace sys {


//! set name of caller thread //! set name of caller thread
void set_thread_name(const std::string& name);
MGE_WIN_DECLSPEC_FUC void set_thread_name(const std::string& name);


#if !__DEPLOY_ON_XP_SP2__ #if !__DEPLOY_ON_XP_SP2__
/*! /*!
* \brief get name of given thread * \brief get name of given thread
* \param tid thread id, or None for the caller thread * \param tid thread id, or None for the caller thread
*/ */
std::string get_thread_name(Maybe<std::thread::id> tid = None);
MGE_WIN_DECLSPEC_FUC std::string get_thread_name(Maybe<std::thread::id> tid = None);
#endif #endif


//! get number of CPU cores on this system //! get number of CPU cores on this system
int get_cpu_count();
MGE_WIN_DECLSPEC_FUC int get_cpu_count();


//! set cpu affinity for caller thread //! set cpu affinity for caller thread
void set_cpu_affinity(const std::vector<int>& cpuset);
MGE_WIN_DECLSPEC_FUC void set_cpu_affinity(const std::vector<int>& cpuset);


//! whether stderr supports ansi color code //! whether stderr supports ansi color code
bool stderr_ansi_color();
MGE_WIN_DECLSPEC_FUC bool stderr_ansi_color();


//! get total ram and free ram in bytes //! get total ram and free ram in bytes
std::pair<size_t, size_t> get_ram_status_bytes();
MGE_WIN_DECLSPEC_FUC std::pair<size_t, size_t> get_ram_status_bytes();


/*! /*!
* \brief invoke a function with time limit * \brief invoke a function with time limit
@@ -207,7 +207,7 @@ public:
virtual void kill_worker() = 0; virtual void kill_worker() = 0;


//! global unique instance //! global unique instance
static TimedFuncInvoker& ins();
MGE_WIN_DECLSPEC_FUC static TimedFuncInvoker& ins();
}; };


} // namespace sys } // namespace sys
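With MGE_WIN_DECLSPEC_FUC on these free functions, an application linked against megengine.dll through its import library can call them directly. A hypothetical consumer-side sketch (not part of the diff), assuming the macro resolves to dllimport in the consumer:

#include <cstdio>
#include "megbrain/system.h"

int main() {
    mgb::sys::set_thread_name("main");                          // exported above
    std::printf("cpu cores: %d\n", mgb::sys::get_cpu_count());  // exported above
    return 0;
}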


+ 36
- 29
src/core/include/megbrain/tensor.h

@@ -50,7 +50,7 @@ public:
} }


//! make a SubTensorSpec from given layout and offset //! make a SubTensorSpec from given layout and offset
static SubTensorSpec make_from_offset_elem(
MGE_WIN_DECLSPEC_FUC static SubTensorSpec make_from_offset_elem(
const TensorLayout& layout, ptrdiff_t offset_elem); const TensorLayout& layout, ptrdiff_t offset_elem);


//! get underlying layout //! get underlying layout
@@ -72,7 +72,7 @@ public:
* \brief merge with another SubTensorSpec: accum offset, and replace * \brief merge with another SubTensorSpec: accum offset, and replace
* layout by rhs * layout by rhs
*/ */
void merge_with(const SubTensorSpec& rhs);
MGE_WIN_DECLSPEC_FUC void merge_with(const SubTensorSpec& rhs);
}; };


/*! /*!
@@ -99,7 +99,7 @@ public:
* \param axis the axis to apply this slice; -1 can be used for * \param axis the axis to apply this slice; -1 can be used for
* flattened layout * flattened layout
*/ */
SubTensorSpec apply(TensorLayout layout, int axis) const;
MGE_WIN_DECLSPEC_FUC SubTensorSpec apply(TensorLayout layout, int axis) const;
}; };


template <class Trait> template <class Trait>
@@ -133,7 +133,7 @@ public:


TensorStorage(const TensorStorage& rhs) { *this = rhs; } TensorStorage(const TensorStorage& rhs) { *this = rhs; }


TensorStorage& operator=(const TensorStorage& rhs);
MGE_WIN_DECLSPEC_FUC TensorStorage& operator=(const TensorStorage& rhs);


/*! /*!
* \brief whether given tensor span is valid in this storage * \brief whether given tensor span is valid in this storage
@@ -153,14 +153,14 @@ public:
* 2. This method would only grow storage, but it would not release * 2. This method would only grow storage, but it would not release
* memory * memory
*/ */
TensorStorage& ensure_size(size_t sz);
MGE_WIN_DECLSPEC_FUC TensorStorage& ensure_size(size_t sz);


/*! /*!
* \brief return a subtensor that shares the memory; the returned * \brief return a subtensor that shares the memory; the returned
* subtensor is not allowed to realloc * subtensor is not allowed to realloc
* \param offset offset given in bytes * \param offset offset given in bytes
*/ */
TensorStorage sub(ptrdiff_t offset) const;
MGE_WIN_DECLSPEC_FUC TensorStorage sub(ptrdiff_t offset) const;


//! apply lazy resize and get ptr //! apply lazy resize and get ptr
dt_byte* ptr() const { dt_byte* ptr() const {
@@ -204,7 +204,8 @@ public:
* changed, the underlying data would be released and this tensor would * changed, the underlying data would be released and this tensor would
* become empty * become empty
*/ */
TensorStorage& comp_node(CompNode node, bool allow_mem_node_change = false);
MGE_WIN_DECLSPEC_FUC TensorStorage& comp_node(
CompNode node, bool allow_mem_node_change = false);


/*! /*!
* \brief copy from another TensorStorage, possibly of other storage * \brief copy from another TensorStorage, possibly of other storage
@@ -216,12 +217,13 @@ public:
* this or src * this or src
*/ */
template <class RTrait> template <class RTrait>
void copy_from(const TensorStorage<RTrait>& src, size_t size) const;
MGE_WIN_DECLSPEC_FUC void copy_from(
const TensorStorage<RTrait>& src, size_t size) const;


/*! /*!
* \brief reset the tensor storage to given memory area * \brief reset the tensor storage to given memory area
*/ */
void reset(CompNode node, size_t size, RawStorage data);
MGE_WIN_DECLSPEC_FUC void reset(CompNode node, size_t size, RawStorage data);


/*! /*!
* \brief make a TensorStorage that shares memory with another * \brief make a TensorStorage that shares memory with another
@@ -233,7 +235,8 @@ public:
template < template <
class RTrait, typename = typename std::enable_if< class RTrait, typename = typename std::enable_if<
!std::is_same<Trait, RTrait>::value>::type> !std::is_same<Trait, RTrait>::value>::type>
static TensorStorage make_proxy(const TensorStorage<RTrait>& src);
MGE_WIN_DECLSPEC_FUC static TensorStorage make_proxy(
const TensorStorage<RTrait>& src);


/*! /*!
* \brief make a DeviceTensorStorage on default_cpu * \brief make a DeviceTensorStorage on default_cpu
@@ -302,9 +305,9 @@ private:
on_invalid_comp_node(); on_invalid_comp_node();
} }


dt_byte* apply_lazy_and_get_ptr();
MGE_WIN_DECLSPEC_FUC dt_byte* apply_lazy_and_get_ptr();


[[noreturn]] static void on_invalid_comp_node();
[[noreturn]] MGE_WIN_DECLSPEC_FUC static void on_invalid_comp_node();
}; };


template <class TensorStorage> template <class TensorStorage>
@@ -326,30 +329,31 @@ class TensorND {
public: public:
using ChainReturnType = TensorND<TensorStorage>; using ChainReturnType = TensorND<TensorStorage>;


TensorND();
MGE_WIN_DECLSPEC_FUC TensorND();


explicit TensorND(CompNode node);
MGE_WIN_DECLSPEC_FUC explicit TensorND(CompNode node);


explicit TensorND(DType dtype);
MGE_WIN_DECLSPEC_FUC explicit TensorND(DType dtype);


TensorND(CompNode node, DType dtype);
MGE_WIN_DECLSPEC_FUC TensorND(CompNode node, DType dtype);


//! allocate contiguous tensor //! allocate contiguous tensor
TensorND(
MGE_WIN_DECLSPEC_FUC TensorND(
CompNode node, const TensorShape& shape, DType dtype = dtype::Float32{}, CompNode node, const TensorShape& shape, DType dtype = dtype::Float32{},
TensorFormat format = {}); TensorFormat format = {});


//! allocate contiguous tensor from given comp node and layout; layout //! allocate contiguous tensor from given comp node and layout; layout
//! is required to be contiguous, and its dtype and format would be used //! is required to be contiguous, and its dtype and format would be used
TensorND(CompNode node, const TensorLayout& layout);
MGE_WIN_DECLSPEC_FUC TensorND(CompNode node, const TensorLayout& layout);


/* ================= shape and basic functionality ================= */ /* ================= shape and basic functionality ================= */


//! get subtensor according to given slices //! get subtensor according to given slices
ChainReturnType operator[](std::initializer_list<Slice> slice) const;
MGE_WIN_DECLSPEC_FUC ChainReturnType
operator[](std::initializer_list<Slice> slice) const;


//! get subtensor according to spec //! get subtensor according to spec
ChainReturnType sub(const SubTensorSpec& spec) const;
MGE_WIN_DECLSPEC_FUC ChainReturnType sub(const SubTensorSpec& spec) const;


//! whether underlying storage is empty //! whether underlying storage is empty
bool empty() const { return m_storage.empty(); } bool empty() const { return m_storage.empty(); }
@@ -409,19 +413,21 @@ public:
* *
* dtype and format would not be changed * dtype and format would not be changed
*/ */
ChainReturnType& resize(const TensorShape& shape);
MGE_WIN_DECLSPEC_FUC ChainReturnType& resize(const TensorShape& shape);


/*! /*!
* \brief totally reset the tensor to given storage and layout * \brief totally reset the tensor to given storage and layout
*/ */
ChainReturnType& reset(TensorStorage storage, const TensorLayout& layout);
MGE_WIN_DECLSPEC_FUC ChainReturnType& reset(
TensorStorage storage, const TensorLayout& layout);


/* ================= getter and setters ================= */ /* ================= getter and setters ================= */


/*! /*!
* \brief change comp node; see TensorStorage::comp_node() * \brief change comp node; see TensorStorage::comp_node()
*/ */
ChainReturnType& comp_node(CompNode comp_node, bool allow_mem_node_change = false);
MGE_WIN_DECLSPEC_FUC ChainReturnType& comp_node(
CompNode comp_node, bool allow_mem_node_change = false);


CompNode comp_node() const { return m_storage.comp_node(); } CompNode comp_node() const { return m_storage.comp_node(); }


@@ -431,7 +437,7 @@ public:
* \brief change the storage and invalidate all data, resulting in an * \brief change the storage and invalidate all data, resulting in an
* empty tensor * empty tensor
*/ */
ChainReturnType& storage(const TensorStorage& storage);
MGE_WIN_DECLSPEC_FUC ChainReturnType& storage(const TensorStorage& storage);


//! get data type //! get data type
DType dtype() const { return m_layout.dtype; } DType dtype() const { return m_layout.dtype; }
@@ -444,14 +450,14 @@ public:
* *
* layout would be cleared (reset to ndim=0) if dtype actually changes * layout would be cleared (reset to ndim=0) if dtype actually changes
*/ */
ChainReturnType& dtype(DType dtype);
MGE_WIN_DECLSPEC_FUC ChainReturnType& dtype(DType dtype);


/*! /*!
* \brief change underlying tensor format * \brief change underlying tensor format
* *
* layout would be cleared (reset to ndim=0) if format actually changes * layout would be cleared (reset to ndim=0) if format actually changes
*/ */
ChainReturnType& format(TensorFormat format);
MGE_WIN_DECLSPEC_FUC ChainReturnType& format(TensorFormat format);


/*! /*!
* \brief copy from another tensor and initialize contiguous layout * \brief copy from another tensor and initialize contiguous layout
@@ -470,7 +476,7 @@ public:
* to be contiguous. * to be contiguous.
*/ */
template <class RStorage> template <class RStorage>
ChainReturnType& copy_from(const TensorND<RStorage>& src);
MGE_WIN_DECLSPEC_FUC ChainReturnType& copy_from(const TensorND<RStorage>& src);


/*! /*!
* \brief copy from another tensor of the same shape, retaining current * \brief copy from another tensor of the same shape, retaining current
@@ -481,7 +487,8 @@ public:
* contiguous. * contiguous.
*/ */
template <class RStorage> template <class RStorage>
const ChainReturnType& copy_from_fixlayout(const TensorND<RStorage>& src) const;
MGE_WIN_DECLSPEC_FUC const ChainReturnType& copy_from_fixlayout(
const TensorND<RStorage>& src) const;


//! non-const version of copy_from_fixlayout //! non-const version of copy_from_fixlayout
template <class RStorage> template <class RStorage>
@@ -547,7 +554,7 @@ public:
/*! /*!
* \brief call memset in the data of a device tensor * \brief call memset in the data of a device tensor
*/ */
void dev_tensor_memset(const DeviceTensorND& tensor, int val);
MGE_WIN_DECLSPEC_FUC void dev_tensor_memset(const DeviceTensorND& tensor, int val);


/*! /*!
* \brief fill zeros in the content of a dev tensor * \brief fill zeros in the content of a dev tensor


+ 6
- 5
src/core/include/megbrain/utils/debug.h

@@ -28,7 +28,7 @@ public:
using SystemError::SystemError; using SystemError::SystemError;


//! function to throw this exception; could be overwritten //! function to throw this exception; could be overwritten
static void (*throw_)();
static MGE_WIN_DECLSPEC_DATA void (*throw_)();
}; };


struct BacktraceResult { struct BacktraceResult {
@@ -53,7 +53,7 @@ BacktraceResult backtrace(int nr_exclude = 1);
* 1: log warning message * 1: log warning message
* 2: throw ForkAfterCudaError() exception * 2: throw ForkAfterCudaError() exception
*/ */
void set_fork_cuda_warning_flag(int flag);
MGE_WIN_DECLSPEC_FUC void set_fork_cuda_warning_flag(int flag);


/*! /*!
* \brief suppress fork warning in this scope * \brief suppress fork warning in this scope
@@ -79,7 +79,8 @@ public:
* The binary can be parsed by `megbrain.plugin.load_tensor_binary` python * The binary can be parsed by `megbrain.plugin.load_tensor_binary` python
* function * function
*/ */
std::string dump_tensor(const HostTensorND& value, const std::string& name);
MGE_WIN_DECLSPEC_FUC std::string dump_tensor(
const HostTensorND& value, const std::string& name);


static inline std::string dump_tensor( static inline std::string dump_tensor(
const DeviceTensorND& value, const std::string& name) { const DeviceTensorND& value, const std::string& name) {
@@ -87,7 +88,7 @@ static inline std::string dump_tensor(
} }


//! write the value of a string to file //! write the value of a string to file
void write_to_file(
MGE_WIN_DECLSPEC_FUC void write_to_file(
const char* filename, const std::string& content, const char* mode = "wb"); const char* filename, const std::string& content, const char* mode = "wb");


/*! /*!
@@ -96,7 +97,7 @@ void write_to_file(
* \return None if tensors are considered equal; or a human-readable * \return None if tensors are considered equal; or a human-readable
* message indicating their difference * message indicating their difference
*/ */
Maybe<std::string> compare_tensor_value(
MGE_WIN_DECLSPEC_FUC Maybe<std::string> compare_tensor_value(
const HostTensorND& expect, const char* expect_expr, const HostTensorND& get, const HostTensorND& expect, const char* expect_expr, const HostTensorND& get,
const char* get_expr, float maxerr); const char* get_expr, float maxerr);




+ 3
- 3
src/core/include/megbrain/utils/event.h

@@ -65,7 +65,7 @@ public:
class ReceiverHandlerImpl; class ReceiverHandlerImpl;
struct ReceiverHandlerImplDeleter { struct ReceiverHandlerImplDeleter {
public: public:
void operator()(ReceiverHandlerImpl*);
MGE_WIN_DECLSPEC_FUC void operator()(ReceiverHandlerImpl*);
}; };
using ReceiverHandler = using ReceiverHandler =
std::unique_ptr<ReceiverHandlerImpl, ReceiverHandlerImplDeleter>; std::unique_ptr<ReceiverHandlerImpl, ReceiverHandlerImplDeleter>;
@@ -109,8 +109,8 @@ public:
private: private:
std::vector<ReceiverHandler> m_permanent_handler; std::vector<ReceiverHandler> m_permanent_handler;


ReceiverHandler do_register_receiver(
Typeinfo* type, std::unique_ptr<ReceiverBase> receiver);
MGE_WIN_DECLSPEC_FUC ReceiverHandler
do_register_receiver(Typeinfo* type, std::unique_ptr<ReceiverBase> receiver);
}; };


} // namespace mgb } // namespace mgb


+ 5
- 4
src/core/include/megbrain/utils/hash.h

@@ -14,6 +14,7 @@
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include "megbrain/utils/thin/function.h" #include "megbrain/utils/thin/function.h"
#include "megbrain_build_config.h"


namespace mgb { namespace mgb {


@@ -57,14 +58,14 @@ class XXHash {
long long m_state[11]; long long m_state[11];


public: public:
XXHash();
void reset();
MGE_WIN_DECLSPEC_FUC XXHash();
MGE_WIN_DECLSPEC_FUC void reset();


//! update internal state, and return *this //! update internal state, and return *this
XXHash& update(const void* data, size_t len);
MGE_WIN_DECLSPEC_FUC XXHash& update(const void* data, size_t len);


//! get hash value, guaranteed to be non-zero //! get hash value, guaranteed to be non-zero
uint64_t digest() const;
MGE_WIN_DECLSPEC_FUC uint64_t digest() const;
}; };


/*! /*!


+ 3
- 3
src/core/include/megbrain/utils/hashable.h

@@ -144,7 +144,7 @@ public:
*/ */
template <typename T> template <typename T>
class ScalarHash final : public HashableVD { class ScalarHash final : public HashableVD {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


union U { union U {
T t; T t;
@@ -181,7 +181,7 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(ScalarHash<T>);
*/ */
template <typename T> template <typename T>
class PODHash final : public HashableVD { class PODHash final : public HashableVD {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


static_assert(is_location_invariant<T>::value, "key must be location invariant"); static_assert(is_location_invariant<T>::value, "key must be location invariant");


@@ -219,7 +219,7 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(PODHash<T>);
* \brief wraps around a raw pointer to Hashable object * \brief wraps around a raw pointer to Hashable object
*/ */
class HashableObjPtrWrapper final : public HashableVD { class HashableObjPtrWrapper final : public HashableVD {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


const Hashable* m_ptr; const Hashable* m_ptr;




+ 10
- 7
src/core/include/megbrain/utils/infile_persistent_cache.h

@@ -60,19 +60,22 @@ class InFilePersistentCache final : public PersistentCache {
void read_cache(Input& inp); void read_cache(Input& inp);


public: public:
InFilePersistentCache() = default;
InFilePersistentCache(const char* path, bool always_open = false);
InFilePersistentCache(const uint8_t* bin, size_t size);
MGE_WIN_DECLSPEC_FUC InFilePersistentCache() = default;
MGE_WIN_DECLSPEC_FUC InFilePersistentCache(
const char* path, bool always_open = false);
MGE_WIN_DECLSPEC_FUC InFilePersistentCache(const uint8_t* bin, size_t size);


/** /**
* \warning You should invoke \c dump_cache manually to save the cache * \warning You should invoke \c dump_cache manually to save the cache
* file. * file.
*/ */
void dump_cache(const char* path);
void dump_cache(OutputFile* out_file);
MGE_WIN_DECLSPEC_FUC void dump_cache(const char* path);
MGE_WIN_DECLSPEC_FUC void dump_cache(OutputFile* out_file);


Maybe<Blob> get(const std::string& category, const Blob& key) override;
void put(const std::string& category, const Blob& key, const Blob& value) override;
MGE_WIN_DECLSPEC_FUC Maybe<Blob> get(
const std::string& category, const Blob& key) override;
MGE_WIN_DECLSPEC_FUC void put(
const std::string& category, const Blob& key, const Blob& value) override;
bool support_dump_cache() override { return true; } bool support_dump_cache() override { return true; }
}; };
} // namespace mgb } // namespace mgb
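The exported constructors and dump_cache() overloads let a Windows application keep this cache on disk across runs. A hedged usage sketch; the file name and flow are illustrative only, and set_impl() is the exported installer shown in persistent_cache.h further below:

#include <memory>
#include "megbrain/utils/infile_persistent_cache.h"

void use_file_backed_cache() {
    auto cache = std::make_shared<mgb::InFilePersistentCache>(
            "algo_cache.bin", /* always_open = */ true);
    mgb::PersistentCache::set_impl(cache);  // install the file-backed cache
    // ... run workloads; get()/put() fill the cache ...
    cache->dump_cache("algo_cache.bin");    // must be called manually (see the \warning above)
}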


+ 18
- 16
src/core/include/megbrain/utils/json.h

@@ -28,19 +28,21 @@ class Value : public std::enable_shared_from_this<Value>, public DynTypeObj {
public: public:
virtual void writeto(std::string& fout, int indent = 0) const = 0; virtual void writeto(std::string& fout, int indent = 0) const = 0;


void writeto_fpath(const std::string& fout_path, int indent = 0) const {
MGE_WIN_DECLSPEC_FUC void writeto_fpath(
const std::string& fout_path, int indent = 0) const {
writeto_fpath(fout_path.c_str(), indent); writeto_fpath(fout_path.c_str(), indent);
} }


void writeto_fpath(const char* fout_path, int indent = 0) const;
MGE_WIN_DECLSPEC_FUC void writeto_fpath(
const char* fout_path, int indent = 0) const;


virtual std::string to_string(int indent = 0) const final;
MGE_WIN_DECLSPEC_FUC virtual std::string to_string(int indent = 0) const final;


virtual ~Value() = default; virtual ~Value() = default;
}; };


class Number final : public Value { class Number final : public Value {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


double m_val; double m_val;


@@ -59,7 +61,7 @@ public:
}; };


class NumberInt final : public Value { class NumberInt final : public Value {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


int64_t m_val; int64_t m_val;


@@ -70,7 +72,7 @@ public:
return std::make_shared<NumberInt>(v); return std::make_shared<NumberInt>(v);
} }


void writeto(std::string& fout, int indent = 0) const override;
MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


auto&& get_impl() { return m_val; } auto&& get_impl() { return m_val; }


@@ -78,7 +80,7 @@ public:
}; };


class Bool final : public Value { class Bool final : public Value {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


bool m_val; bool m_val;


@@ -87,7 +89,7 @@ public:


static std::shared_ptr<Bool> make(bool v); static std::shared_ptr<Bool> make(bool v);


void writeto(std::string& fout, int indent = 0) const override;
MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


auto&& get_impl() { return m_val; } auto&& get_impl() { return m_val; }


@@ -95,7 +97,7 @@ public:
}; };


class String final : public Value { class String final : public Value {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


std::string m_val; std::string m_val;


@@ -110,7 +112,7 @@ public:


bool operator==(const String& rhs) const { return m_val == rhs.m_val; } bool operator==(const String& rhs) const { return m_val == rhs.m_val; }


void writeto(std::string& fout, int indent = 0) const override;
MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


auto&& get_impl() { return m_val; } auto&& get_impl() { return m_val; }


@@ -118,7 +120,7 @@ public:
}; };


class Object final : public Value { class Object final : public Value {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


std::unordered_map<String, std::shared_ptr<Value>, StdHashAdaptor<String>> m_val; std::unordered_map<String, std::shared_ptr<Value>, StdHashAdaptor<String>> m_val;


@@ -140,7 +142,7 @@ public:


std::shared_ptr<Value>& operator[](const char* s) { return m_val[std::string(s)]; } std::shared_ptr<Value>& operator[](const char* s) { return m_val[std::string(s)]; }


void writeto(std::string& fout, int indent = 0) const override;
MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


auto&& get_impl() { return m_val; } auto&& get_impl() { return m_val; }


@@ -148,7 +150,7 @@ public:
}; };


class Array final : public Value { class Array final : public Value {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


std::vector<std::shared_ptr<Value>> m_val; std::vector<std::shared_ptr<Value>> m_val;


@@ -162,7 +164,7 @@ public:


std::shared_ptr<Value>& operator[](size_t idx) { return m_val.at(idx); } std::shared_ptr<Value>& operator[](size_t idx) { return m_val.at(idx); }


void writeto(std::string& fout, int indent = 0) const override;
MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


auto&& get_impl() { return m_val; } auto&& get_impl() { return m_val; }


@@ -170,7 +172,7 @@ public:
}; };


class Null final : public Value { class Null final : public Value {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


public: public:
static std::shared_ptr<Value> make() { static std::shared_ptr<Value> make() {
@@ -178,7 +180,7 @@ public:
return v; return v;
} }


void writeto(std::string& fout, int /*indent*/) const override;
MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int /*indent*/) const override;
}; };


class Serializable { class Serializable {


+ 10
- 9
src/core/include/megbrain/utils/mempool.h

@@ -15,6 +15,7 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "megbrain_build_config.h"


namespace mgb { namespace mgb {


@@ -25,24 +26,24 @@ class MemPoolStorage {
std::vector<void*> m_free; std::vector<void*> m_free;


public: public:
MemPoolStorage() noexcept;
MemPoolStorage(MemPoolStorage&& rhs) noexcept;
~MemPoolStorage() noexcept;
MemPoolStorage& operator=(MemPoolStorage&& rhs) noexcept;
MGE_WIN_DECLSPEC_FUC MemPoolStorage() noexcept;
MGE_WIN_DECLSPEC_FUC MemPoolStorage(MemPoolStorage&& rhs) noexcept;
MGE_WIN_DECLSPEC_FUC ~MemPoolStorage() noexcept;
MGE_WIN_DECLSPEC_FUC MemPoolStorage& operator=(MemPoolStorage&& rhs) noexcept;


void swap(MemPoolStorage& other);
MGE_WIN_DECLSPEC_FUC void swap(MemPoolStorage& other);


/*! /*!
* \brief allocate storage for an object of specified size * \brief allocate storage for an object of specified size
* \param elem_size size of the object; it must remain unchanged * \param elem_size size of the object; it must remain unchanged
* during lifespan of this MemPoolStorage * during lifespan of this MemPoolStorage
*/ */
void* alloc(size_t elem_size);
void free(void* ptr);
void reorder_free();
MGE_WIN_DECLSPEC_FUC void* alloc(size_t elem_size);
MGE_WIN_DECLSPEC_FUC void free(void* ptr);
MGE_WIN_DECLSPEC_FUC void reorder_free();


//! clear all allocated storage //! clear all allocated storage
void clear();
MGE_WIN_DECLSPEC_FUC void clear();


void disable_freelist() { m_disable_freelist = true; } void disable_freelist() { m_disable_freelist = true; }
}; };


+ 17
- 5
src/core/include/megbrain/utils/metahelper.h

@@ -115,6 +115,13 @@ public: \
private: \ private: \
static ::mgb::Typeinfo sm_typeinfo static ::mgb::Typeinfo sm_typeinfo


#define MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT \
public: \
static inline ::mgb::Typeinfo* typeinfo() { return &sm_typeinfo; } \
\
private: \
static MGE_WIN_DECLSPEC_DATA ::mgb::Typeinfo sm_typeinfo

#if MGB_VERBOSE_TYPEINFO_NAME #if MGB_VERBOSE_TYPEINFO_NAME
//! get class name from class object //! get class name from class object
#define _MGB_TYPEINFO_CLASS_NAME(_cls) #_cls #define _MGB_TYPEINFO_CLASS_NAME(_cls) #_cls
@@ -133,6 +140,11 @@ public: \
::mgb::Typeinfo* dyn_typeinfo() const override final; \ ::mgb::Typeinfo* dyn_typeinfo() const override final; \
MGB_TYPEINFO_OBJ_DECL MGB_TYPEINFO_OBJ_DECL


#define MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT \
public: \
MGE_WIN_DECLSPEC_FUC ::mgb::Typeinfo* dyn_typeinfo() const override final; \
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT

//! put in the impl file of a final class inherited from DynTypeObj //! put in the impl file of a final class inherited from DynTypeObj
#define MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) \ #define MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) \
_MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL \ _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL \
@@ -364,7 +376,7 @@ public:
virtual ~UserData() = default; virtual ~UserData() = default;
}; };


~UserDataContainer() noexcept;
MGE_WIN_DECLSPEC_FUC ~UserDataContainer() noexcept;


/*! /*!
* \brief register new user data * \brief register new user data
@@ -430,10 +442,10 @@ public:
} }


private: private:
void do_add(Typeinfo* type, std::shared_ptr<UserData> ptr);
std::pair<void* const*, size_t> do_get(Typeinfo* type) const;
void* do_get_one(Typeinfo* type) const;
int do_pop(Typeinfo* type);
MGE_WIN_DECLSPEC_FUC void do_add(Typeinfo* type, std::shared_ptr<UserData> ptr);
MGE_WIN_DECLSPEC_FUC std::pair<void* const*, size_t> do_get(Typeinfo* type) const;
MGE_WIN_DECLSPEC_FUC void* do_get_one(Typeinfo* type) const;
MGE_WIN_DECLSPEC_FUC int do_pop(Typeinfo* type);


//! use a set to help erase //! use a set to help erase
std::unordered_set<std::shared_ptr<UserData>> m_refkeeper; std::unordered_set<std::shared_ptr<UserData>> m_refkeeper;
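The two new *_WITH_EXPORT macros above are what the json.h and hashable.h changes rely on: they put MGE_WIN_DECLSPEC_FUC on dyn_typeinfo() and MGE_WIN_DECLSPEC_DATA on the static sm_typeinfo object, so a Typeinfo pointer compares equal no matter which module asks for it. A hypothetical final class opting in (example name made up, and assuming DynTypeObj is directly derivable here, per the "final class inherited from DynTypeObj" comment above):

// "ExportedMarker" is a made-up example; real users are e.g. the json Value subclasses.
class ExportedMarker final : public mgb::DynTypeObj {
    MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;  // exported dyn_typeinfo() + exported sm_typeinfo
};

// in exactly one .cpp compiled into megengine.dll:
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ExportedMarker);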


+ 2
- 2
src/core/include/megbrain/utils/metahelper_basic.h

@@ -22,7 +22,7 @@
namespace mgb { namespace mgb {


namespace metahelper_detail { namespace metahelper_detail {
[[noreturn]] void on_maybe_invalid_val_access();
[[noreturn]] MGE_WIN_DECLSPEC_FUC void on_maybe_invalid_val_access();


template <class T, class Tuple, size_t... I> template <class T, class Tuple, size_t... I>
constexpr T make_from_tuple_impl(Tuple&& t, std::index_sequence<I...>) { constexpr T make_from_tuple_impl(Tuple&& t, std::index_sequence<I...>) {
@@ -140,7 +140,7 @@ constexpr bool is_complete_v =


//! a None type to represent invalid Maybe //! a None type to represent invalid Maybe
class None {}; class None {};
extern class None None;
MGE_WIN_DECLSPEC_DATA extern class None None;


//! an optional storage for arbitrary object //! an optional storage for arbitrary object
template <typename T> template <typename T>


+ 2
- 2
src/core/include/megbrain/utils/persistent_cache.h

@@ -24,7 +24,7 @@ namespace mgb {
* The implementation must be thread safe. * The implementation must be thread safe.
*/ */
class PersistentCache { class PersistentCache {
static std::shared_ptr<PersistentCache> sm_impl;
static MGE_WIN_DECLSPEC_DATA std::shared_ptr<PersistentCache> sm_impl;


public: public:
virtual ~PersistentCache() = default; virtual ~PersistentCache() = default;
@@ -42,7 +42,7 @@ public:
virtual bool support_dump_cache() { return false; } virtual bool support_dump_cache() { return false; }


//! set an implementation; return the original implementation //! set an implementation; return the original implementation
static std::shared_ptr<PersistentCache> set_impl(
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<PersistentCache> set_impl(
std::shared_ptr<PersistentCache> impl); std::shared_ptr<PersistentCache> impl);


//! get the instance; the default implementation just caches in //! get the instance; the default implementation just caches in


+ 9
- 9
src/core/include/megbrain/utils/thread_impl_1.h

@@ -68,28 +68,28 @@ class SCQueueSynchronizer {
std::thread m_worker_thread; std::thread m_worker_thread;


public: public:
SCQueueSynchronizer(size_t max_spin);
MGE_WIN_DECLSPEC_FUC SCQueueSynchronizer(size_t max_spin);


~SCQueueSynchronizer() noexcept;
MGE_WIN_DECLSPEC_FUC ~SCQueueSynchronizer() noexcept;


bool worker_started() const { return m_worker_started; } bool worker_started() const { return m_worker_started; }


#ifdef WIN32 #ifdef WIN32
static bool is_into_atexit;
static MGE_WIN_DECLSPEC_DATA bool is_into_atexit;
void set_finish_called(bool status) { m_wait_finish_called = status; } void set_finish_called(bool status) { m_wait_finish_called = status; }
#endif #endif


//! get global default max spin from env //! get global default max spin from env
static size_t get_default_max_spin();
MGE_WIN_DECLSPEC_FUC static size_t get_default_max_spin();


void start_worker(std::thread thread);
MGE_WIN_DECLSPEC_FUC void start_worker(std::thread thread);


//! add a new task in producer thread; require worker to have //! add a new task in producer thread; require worker to have
//! started //! started
void producer_add();
MGE_WIN_DECLSPEC_FUC void producer_add();


//! wait for currently added tasks to finish //! wait for currently added tasks to finish
void producer_wait();
MGE_WIN_DECLSPEC_FUC void producer_wait();


bool check_finished() const { bool check_finished() const {
return m_finished_task.load(std::memory_order_acquire) == return m_finished_task.load(std::memory_order_acquire) ==
@@ -102,13 +102,13 @@ public:
* \param min minimal number of tasks to be fetched * \param min minimal number of tasks to be fetched
* \return number of tasks fetched; return 0 if worker should exit * \return number of tasks fetched; return 0 if worker should exit
*/ */
size_t consumer_fetch(size_t max, size_t min = 1);
MGE_WIN_DECLSPEC_FUC size_t consumer_fetch(size_t max, size_t min = 1);


/*! /*!
* \brief ack that tasks have been processed in consumer * \brief ack that tasks have been processed in consumer
* \param nr number of tasks to be committed * \param nr number of tasks to be committed
*/ */
void consumer_commit(size_t nr);
MGE_WIN_DECLSPEC_FUC void consumer_commit(size_t nr);
}; };


/*! /*!


+ 2
- 1
src/core/include/megbrain/utils/timer.h

@@ -12,6 +12,7 @@
#pragma once #pragma once


#include <string> #include <string>
#include "megbrain_build_config.h"


namespace mgb { namespace mgb {


@@ -34,7 +35,7 @@ class Timer {
TimeSpec m_start; TimeSpec m_start;


public: public:
static TimeSpec get_time();
MGE_WIN_DECLSPEC_FUC static TimeSpec get_time();


Timer() { reset(); } Timer() { reset(); }




+ 3
- 1
src/core/include/megbrain/version.h

@@ -11,6 +11,8 @@


#pragma once #pragma once


#include "megbrain_build_config.h"

#define MGB_MAJOR 8 #define MGB_MAJOR 8
#define MGB_MINOR 9999 #define MGB_MINOR 9999
#define MGB_PATCH 0 #define MGB_PATCH 0
@@ -24,7 +26,7 @@ struct Version {
int major, minor, patch, is_dev; int major, minor, patch, is_dev;
}; };


Version get_version();
MGE_WIN_DECLSPEC_FUC Version get_version();
} // namespace mgb } // namespace mgb


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
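Since get_version() now carries the export annotation as well, a small smoke test of a freshly linked megengine.dll is possible from the consumer side (a hypothetical check, not part of the change):

#include <cstdio>
#include "megbrain/version.h"

int main() {
    mgb::Version v = mgb::get_version();
    std::printf("megbrain %d.%d.%d%s\n", v.major, v.minor, v.patch,
                v.is_dev ? "-dev" : "");
    return 0;
}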

+ 19
- 15
src/custom/include/megbrain/custom/manager.h

@@ -24,23 +24,27 @@ class CustomOpManager {


public: public:
PREVENT_COPY_AND_ASSIGN(CustomOpManager); PREVENT_COPY_AND_ASSIGN(CustomOpManager);
static CustomOpManager* inst(void);
~CustomOpManager();
MGE_WIN_DECLSPEC_FUC static CustomOpManager* inst(void);
MGE_WIN_DECLSPEC_FUC ~CustomOpManager();


std::shared_ptr<CustomOp> insert(const std::string& name, uint32_t version);
bool erase(const std::string& name);
bool erase(const RunTimeId& id);
MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> insert(
const std::string& name, uint32_t version);
MGE_WIN_DECLSPEC_FUC bool erase(const std::string& name);
MGE_WIN_DECLSPEC_FUC bool erase(const RunTimeId& id);


std::shared_ptr<CustomOp> find_or_reg(const std::string& name, uint32_t version);
MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> find_or_reg(
const std::string& name, uint32_t version);


RunTimeId to_id(const std::string& name) const;
std::string to_name(const RunTimeId& id) const;
MGE_WIN_DECLSPEC_FUC RunTimeId to_id(const std::string& name) const;
MGE_WIN_DECLSPEC_FUC std::string to_name(const RunTimeId& id) const;


std::shared_ptr<const CustomOp> find(const std::string& name) const;
std::shared_ptr<const CustomOp> find(const RunTimeId& id) const;
MGE_WIN_DECLSPEC_FUC std::shared_ptr<const CustomOp> find(
const std::string& name) const;
MGE_WIN_DECLSPEC_FUC std::shared_ptr<const CustomOp> find(
const RunTimeId& id) const;


std::vector<std::string> op_name_list(void);
std::vector<RunTimeId> op_id_list(void);
MGE_WIN_DECLSPEC_FUC std::vector<std::string> op_name_list(void);
MGE_WIN_DECLSPEC_FUC std::vector<RunTimeId> op_id_list(void);
}; };


class CustomLib { class CustomLib {
@@ -67,10 +71,10 @@ class LibManager {
public: public:
PREVENT_COPY_AND_ASSIGN(LibManager); PREVENT_COPY_AND_ASSIGN(LibManager);


static LibManager* inst(void);
const std::vector<std::string>& install(
MGE_WIN_DECLSPEC_FUC static LibManager* inst(void);
MGE_WIN_DECLSPEC_FUC const std::vector<std::string>& install(
const std::string& name, const std::string& path); const std::string& name, const std::string& path);
bool uninstall(const std::string& name);
MGE_WIN_DECLSPEC_FUC bool uninstall(const std::string& name);
friend class CustomOpManager; friend class CustomOpManager;
}; };




+ 51
- 45
src/custom/include/megbrain/custom/op.h

@@ -34,24 +34,25 @@ using RunTimeId = uint64_t;


class ArgInfo { class ArgInfo {
CUSTOM_PIMPL_CLS_DECL(ArgInfo); CUSTOM_PIMPL_CLS_DECL(ArgInfo);
ArgInfo(const std::string& name, const std::string& desc,
MGE_WIN_DECLSPEC_FUC ArgInfo(
const std::string& name, const std::string& desc,
const std::unordered_set<std::string>& dtypes, const int& ndim, const std::unordered_set<std::string>& dtypes, const int& ndim,
const std::string& mem_stgy); const std::string& mem_stgy);


const std::string& name(void) const;
const std::string& desc(void) const;
const std::unordered_set<std::string>& dtypes(void) const;
int ndim(void) const;
const std::string& mem_strategy(void) const;
MGE_WIN_DECLSPEC_FUC const std::string& name(void) const;
MGE_WIN_DECLSPEC_FUC const std::string& desc(void) const;
MGE_WIN_DECLSPEC_FUC const std::unordered_set<std::string>& dtypes(void) const;
MGE_WIN_DECLSPEC_FUC int ndim(void) const;
MGE_WIN_DECLSPEC_FUC const std::string& mem_strategy(void) const;


std::string str() const;
MGE_WIN_DECLSPEC_FUC std::string str() const;
}; };


class CustomOp { class CustomOp {
std::unique_ptr<void, void_deleter> m_impl; std::unique_ptr<void, void_deleter> m_impl;


public: public:
CustomOp(const std::string& op_type, uint32_t version);
MGE_WIN_DECLSPEC_FUC CustomOp(const std::string& op_type, uint32_t version);
PREVENT_COPY_AND_ASSIGN(CustomOp); PREVENT_COPY_AND_ASSIGN(CustomOp);


using DeviceInferFuncPtr = using DeviceInferFuncPtr =
@@ -70,65 +71,70 @@ public:
void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&);


// write for forward // write for forward
CustomOp& set_device_infer(DeviceInferFuncPtr func);
CustomOp& set_shape_infer(ShapeInferFuncPtr func);
CustomOp& set_dtype_infer(DTypeInferFuncPtr func);
CustomOp& set_format_infer(FormatInferFuncPtr func);
CustomOp& set_preprocess(PreprocessFuncPtr func);
CustomOp& set_preprocess(const std::string& device, PreprocessFuncPtr func);
CustomOp& set_postprocess(PostprocessFuncPtr func);
CustomOp& set_postprocess(const std::string& device, PostprocessFuncPtr func);
CustomOp& set_compute(ComputeFuncPtr func);
CustomOp& set_compute(const std::string& device, ComputeFuncPtr func);

CustomOp& set_description(const std::string& op_desc);
CustomOp& add_input(
MGE_WIN_DECLSPEC_FUC CustomOp& set_device_infer(DeviceInferFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_shape_infer(ShapeInferFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_dtype_infer(DTypeInferFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_format_infer(FormatInferFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_preprocess(PreprocessFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_preprocess(
const std::string& device, PreprocessFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_postprocess(PostprocessFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_postprocess(
const std::string& device, PostprocessFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_compute(ComputeFuncPtr func);
MGE_WIN_DECLSPEC_FUC CustomOp& set_compute(
const std::string& device, ComputeFuncPtr func);

MGE_WIN_DECLSPEC_FUC CustomOp& set_description(const std::string& op_desc);
MGE_WIN_DECLSPEC_FUC CustomOp& add_input(
const std::string& name, const std::string& desc,
const std::initializer_list<std::string>& legal_dtypes = {"float32"},
int dims = -1, const std::string& mem_stgy = "default");
CustomOp& add_output(
MGE_WIN_DECLSPEC_FUC CustomOp& add_output(
const std::string& name, const std::string& desc,
const std::initializer_list<std::string>& legal_dtypes = {"float32"},
int dims = -1, const std::string& mem_stgy = "default");
CustomOp& add_input(
MGE_WIN_DECLSPEC_FUC CustomOp& add_input(
const std::string& name,
const std::initializer_list<std::string>& legal_dtypes = {"float32"},
int dims = -1, const std::string& mem_stgy = "default");
CustomOp& add_output(
MGE_WIN_DECLSPEC_FUC CustomOp& add_output(
const std::string& name,
const std::initializer_list<std::string>& legal_dtypes = {"float32"},
int dims = -1, const std::string& mem_stgy = "default");
CustomOp& add_inputs(const size_t& input_num);
CustomOp& add_outputs(const size_t& output_num);
CustomOp& add_param(const std::string& name, const ParamVal& default_val);
CustomOp& add_param(
MGE_WIN_DECLSPEC_FUC CustomOp& add_inputs(const size_t& input_num);
MGE_WIN_DECLSPEC_FUC CustomOp& add_outputs(const size_t& output_num);
MGE_WIN_DECLSPEC_FUC CustomOp& add_param(
const std::string& name, const ParamVal& default_val);
MGE_WIN_DECLSPEC_FUC CustomOp& add_param(
const std::string& name, const std::string& desc,
const ParamVal& default_val);

// read
std::string op_type(void) const;
std::string op_desc(void) const;
RunTimeId runtime_id(void) const;
size_t input_num(void) const;
size_t output_num(void) const;
std::string str(void) const;
const ParamInfo& param_info(void) const;
ArgInfo input_info(size_t idx) const;
ArgInfo output_info(size_t idx) const;
const std::vector<ArgInfo>& inputs_info(void) const;
const std::vector<ArgInfo>& outputs_info(void) const;
MGE_WIN_DECLSPEC_FUC std::string op_type(void) const;
MGE_WIN_DECLSPEC_FUC std::string op_desc(void) const;
MGE_WIN_DECLSPEC_FUC RunTimeId runtime_id(void) const;
MGE_WIN_DECLSPEC_FUC size_t input_num(void) const;
MGE_WIN_DECLSPEC_FUC size_t output_num(void) const;
MGE_WIN_DECLSPEC_FUC std::string str(void) const;
MGE_WIN_DECLSPEC_FUC const ParamInfo& param_info(void) const;
MGE_WIN_DECLSPEC_FUC ArgInfo input_info(size_t idx) const;
MGE_WIN_DECLSPEC_FUC ArgInfo output_info(size_t idx) const;
MGE_WIN_DECLSPEC_FUC const std::vector<ArgInfo>& inputs_info(void) const;
MGE_WIN_DECLSPEC_FUC const std::vector<ArgInfo>& outputs_info(void) const;


// use
std::vector<Device> infer_output_device(
MGE_WIN_DECLSPEC_FUC std::vector<Device> infer_output_device(
const std::vector<Device>&, const Param&) const;
std::vector<Shape> infer_output_shape(
MGE_WIN_DECLSPEC_FUC std::vector<Shape> infer_output_shape(
const std::vector<Shape>&, const Param&) const;
std::vector<DType> infer_output_dtype(
MGE_WIN_DECLSPEC_FUC std::vector<DType> infer_output_dtype(
const std::vector<DType>&, const Param&) const;
std::vector<Format> infer_output_format(
MGE_WIN_DECLSPEC_FUC std::vector<Format> infer_output_format(
const std::vector<Format>&, const Param&) const;
void compute(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) const;
MGE_WIN_DECLSPEC_FUC void compute(
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) const;
};

} // namespace custom
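The class above is a builder-style registration API: construct a CustomOp with a type name and version, chain add_input/add_output/add_param and the set_* callbacks to describe it, then query it through the read accessors. A minimal hedged sketch; the op name, tensor names and the compute callback are invented for illustration, and real registration goes through the op manager rather than a local object:

    #include "megbrain/custom/op.h"

    // callback matching the ComputeFuncPtr signature declared above
    void scale_compute(
            const std::vector<custom::Tensor>& inputs, const custom::Param& params,
            std::vector<custom::Tensor>& outputs);

    void describe_scale_op() {
        custom::CustomOp op("scale", /*version=*/1);
        op.set_description("multiply the input by a constant factor")
                .add_input("x", "input tensor", {"float32"}, 4)
                .add_output("y", "scaled tensor", {"float32"}, 4)
                .add_param("factor", 1.0f)
                .set_compute(scale_compute);
    }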

+ 8
- 8
src/custom/include/megbrain/custom/param.h

@@ -49,15 +49,15 @@ class ParamInfo {
class Param {
CUSTOM_PIMPL_CLS_DECL(Param);


Param(const ParamInfo&);
ParamVal& operator[](const std::string&);
const ParamVal& operator[](const std::string&) const;
const std::unordered_map<std::string, ParamVal>& raw() const;
bool exist(const std::string& name) const;
std::string to_bytes(void) const;
void from_bytes(const std::string&);
MGE_WIN_DECLSPEC_FUC Param(const ParamInfo&);
MGE_WIN_DECLSPEC_FUC ParamVal& operator[](const std::string&);
MGE_WIN_DECLSPEC_FUC const ParamVal& operator[](const std::string&) const;
MGE_WIN_DECLSPEC_FUC const std::unordered_map<std::string, ParamVal>& raw() const;
MGE_WIN_DECLSPEC_FUC bool exist(const std::string& name) const;
MGE_WIN_DECLSPEC_FUC std::string to_bytes(void) const;
MGE_WIN_DECLSPEC_FUC void from_bytes(const std::string&);
};


bool operator==(const Param&, const Param&);
MGE_WIN_DECLSPEC_FUC bool operator==(const Param&, const Param&);


} // namespace custom

+ 17
- 16
src/custom/include/megbrain/custom/param_val.h

@@ -175,15 +175,15 @@ class ParamVal {


public:
template <typename T>
ParamVal(const T& val);
MGE_WIN_DECLSPEC_FUC ParamVal(const T& val);
template <typename T>
ParamVal(const std::initializer_list<T>& val);
MGE_WIN_DECLSPEC_FUC ParamVal(const std::initializer_list<T>& val);


ParamVal();
ParamVal(const char* str);
ParamVal(const std::initializer_list<const char*>& strs);
ParamVal(const std::vector<const char*>& strs);
ParamVal(const ParamVal& rhs);
MGE_WIN_DECLSPEC_FUC ParamVal();
MGE_WIN_DECLSPEC_FUC ParamVal(const char* str);
MGE_WIN_DECLSPEC_FUC ParamVal(const std::initializer_list<const char*>& strs);
MGE_WIN_DECLSPEC_FUC ParamVal(const std::vector<const char*>& strs);
MGE_WIN_DECLSPEC_FUC ParamVal(const ParamVal& rhs);


template <typename T>
ParamVal& operator=(const T& rhs);
@@ -196,18 +196,19 @@ public:
ParamVal& operator=(const ParamVal& rhs);


template <typename T>
const T& as(void) const;
MGE_WIN_DECLSPEC_FUC const T& as(void) const;
template <typename T>
T& as(void);
MGE_WIN_DECLSPEC_FUC T& as(void);


const void* raw_ptr(void) const;
void* raw_ptr(void);
ParamDynType type(void) const;
std::string str(void) const;
size_t size(void) const;
MGE_WIN_DECLSPEC_FUC const void* raw_ptr(void) const;
MGE_WIN_DECLSPEC_FUC void* raw_ptr(void);
MGE_WIN_DECLSPEC_FUC ParamDynType type(void) const;
MGE_WIN_DECLSPEC_FUC std::string str(void) const;
MGE_WIN_DECLSPEC_FUC size_t size(void) const;


static std::string to_bytes(const ParamVal& value);
static ParamVal from_bytes(const std::string& bytes, size_t& offset);
MGE_WIN_DECLSPEC_FUC static std::string to_bytes(const ParamVal& value);
MGE_WIN_DECLSPEC_FUC static ParamVal from_bytes(
const std::string& bytes, size_t& offset);


friend ParamVal operator+(const ParamVal& lhs, const ParamVal& rhs);
friend ParamVal operator-(const ParamVal& lhs, const ParamVal& rhs);
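Taken together, Param and ParamVal give custom ops a small typed, serializable parameter store. A minimal sketch of the value semantics, using only the constructors and accessors declared in this hunk; the particular values are made up:

    #include "megbrain/custom/param_val.h"
    #include <string>

    void param_val_demo() {
        custom::ParamVal threshold = 0.5f;      // template ParamVal(const T&)
        custom::ParamVal mode = "bilinear";     // ParamVal(const char*)
        float t = threshold.as<float>();        // typed access via as<T>()
        std::string bytes = custom::ParamVal::to_bytes(mode);
        size_t offset = 0;
        custom::ParamVal restored = custom::ParamVal::from_bytes(bytes, offset);
        (void)t;
        (void)restored;
    }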


+ 56
- 55
src/custom/include/megbrain/custom/tensor.h

@@ -31,8 +31,8 @@ namespace custom {
custom_type,

class Device {
const void* impl() const;
Device(const void* impl);
MGE_WIN_DECLSPEC_FUC const void* impl() const;
MGE_WIN_DECLSPEC_FUC Device(const void* impl);
CUSTOM_PIMPL_CLS_DECL(Device);

public:
@@ -40,16 +40,16 @@ public:
CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_DEVICE_TYPE_ENUM_DECL)
};


Device(const std::string& device);
Device(const char* device);
Device(DeviceEnum device);
MGE_WIN_DECLSPEC_FUC Device(const std::string& device);
MGE_WIN_DECLSPEC_FUC Device(const char* device);
MGE_WIN_DECLSPEC_FUC Device(DeviceEnum device);


std::string str(void) const;
DeviceEnum enumv(void) const;
MGE_WIN_DECLSPEC_FUC std::string str(void) const;
MGE_WIN_DECLSPEC_FUC DeviceEnum enumv(void) const;


static bool is_legal(const std::string& device);
static bool is_legal(DeviceEnum device);
static std::vector<std::string> legal_devices(void);
MGE_WIN_DECLSPEC_FUC static bool is_legal(const std::string& device);
MGE_WIN_DECLSPEC_FUC static bool is_legal(DeviceEnum device);
MGE_WIN_DECLSPEC_FUC static std::vector<std::string> legal_devices(void);


friend class Tensor;
friend bool operator==(const Device& lhs, const Device& rhs);
@@ -61,19 +61,19 @@ using DeviceEnum = Device::DeviceEnum;
bool operator==(const Device& lhs, const Device& rhs);


class Shape {
const void* impl() const;
Shape(const void* impl);
MGE_WIN_DECLSPEC_FUC const void* impl() const;
MGE_WIN_DECLSPEC_FUC Shape(const void* impl);
CUSTOM_PIMPL_CLS_DECL(Shape);

public:
Shape(const std::vector<size_t>& rhs);
Shape(const std::initializer_list<size_t>& rhs);
MGE_WIN_DECLSPEC_FUC Shape(const std::vector<size_t>& rhs);
MGE_WIN_DECLSPEC_FUC Shape(const std::initializer_list<size_t>& rhs);

size_t& operator[](size_t idx);
size_t operator[](size_t idx) const;


void ndim(size_t dim);
size_t ndim(void) const;
MGE_WIN_DECLSPEC_FUC void ndim(size_t dim);
MGE_WIN_DECLSPEC_FUC size_t ndim(void) const;


friend class Tensor;
friend bool operator==(const Shape& lhs, const Shape& rhs);
@@ -105,8 +105,8 @@ using bfloat16_t = uint16_t;
#define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type,

class DType {
const void* impl() const;
DType(const void* impl);
MGE_WIN_DECLSPEC_FUC const void* impl() const;
MGE_WIN_DECLSPEC_FUC DType(const void* impl);
CUSTOM_PIMPL_CLS_DECL(DType);

public:
@@ -114,23 +114,24 @@ public:
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DTYPE_ENUM_DECL)
};


DType(const std::string& dtype);
DType(const char* dtype);
DType(const std::string& dtype, float scale, uint8_t zero_point = 0);
DType(const char* dtype, float scale, uint8_t zero_point = 0);
DType(DTypeEnum dtype);
DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0);

std::string str(void) const;
DTypeEnum enumv() const;
float scale(void) const;
uint8_t zero_point(void) const;
MGE_WIN_DECLSPEC_FUC DType(const std::string& dtype);
MGE_WIN_DECLSPEC_FUC DType(const char* dtype);
MGE_WIN_DECLSPEC_FUC DType(
const std::string& dtype, float scale, uint8_t zero_point = 0);
MGE_WIN_DECLSPEC_FUC DType(const char* dtype, float scale, uint8_t zero_point = 0);
MGE_WIN_DECLSPEC_FUC DType(DTypeEnum dtype);
MGE_WIN_DECLSPEC_FUC DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0);

MGE_WIN_DECLSPEC_FUC std::string str(void) const;
MGE_WIN_DECLSPEC_FUC DTypeEnum enumv() const;
MGE_WIN_DECLSPEC_FUC float scale(void) const;
MGE_WIN_DECLSPEC_FUC uint8_t zero_point(void) const;
template <typename T>
bool is_compatible(void) const;
MGE_WIN_DECLSPEC_FUC bool is_compatible(void) const;


static bool is_legal(const std::string& dtype);
static bool is_legal(const DTypeEnum& dtype);
static std::vector<std::string> legal_dtypes(void);
MGE_WIN_DECLSPEC_FUC static bool is_legal(const std::string& dtype);
MGE_WIN_DECLSPEC_FUC static bool is_legal(const DTypeEnum& dtype);
MGE_WIN_DECLSPEC_FUC static std::vector<std::string> legal_dtypes(void);


friend class Tensor;
friend bool operator==(const DType& lhs, const DType& rhs);
@@ -180,16 +181,16 @@ bool operator==(const std::string& lhs, const DType& rhs);
bool operator==(const char* lhs, const DType& rhs);

class Format {
const void* impl() const;
Format(const void* impl);
MGE_WIN_DECLSPEC_FUC const void* impl() const;
MGE_WIN_DECLSPEC_FUC Format(const void* impl);
CUSTOM_PIMPL_CLS_DECL(Format);

public:
Format(const std::string& format);
Format(const char* format);
MGE_WIN_DECLSPEC_FUC Format(const std::string& format);
MGE_WIN_DECLSPEC_FUC Format(const char* format);


std::string str(void) const;
bool is_default(void) const;
MGE_WIN_DECLSPEC_FUC std::string str(void) const;
MGE_WIN_DECLSPEC_FUC bool is_default(void) const;


friend class Tensor;
CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
@@ -198,26 +199,26 @@ public:
class Tensor {
void* m_tensor;


const void* impl(void) const;
Tensor(const void* impl);
MGE_WIN_DECLSPEC_FUC const void* impl(void) const;
MGE_WIN_DECLSPEC_FUC Tensor(const void* impl);


const size_t* shapes_raw(void) const;
const ptrdiff_t* strides_raw(void) const;
MGE_WIN_DECLSPEC_FUC const size_t* shapes_raw(void) const;
MGE_WIN_DECLSPEC_FUC const ptrdiff_t* strides_raw(void) const;


public:
Tensor() = delete;
Tensor(const Tensor& rhs);
Tensor& operator=(const Tensor& rhs);
Shape shape(void) const;
DType dtype(void) const;
Format format(void) const;
Device device(void) const;
size_t size(void) const;
std::vector<ptrdiff_t> stride(void) const;
float scale(void) const;
uint8_t zero_point(void) const;
MGE_WIN_DECLSPEC_FUC Tensor(const Tensor& rhs);
MGE_WIN_DECLSPEC_FUC Tensor& operator=(const Tensor& rhs);
MGE_WIN_DECLSPEC_FUC Shape shape(void) const;
MGE_WIN_DECLSPEC_FUC DType dtype(void) const;
MGE_WIN_DECLSPEC_FUC Format format(void) const;
MGE_WIN_DECLSPEC_FUC Device device(void) const;
MGE_WIN_DECLSPEC_FUC size_t size(void) const;
MGE_WIN_DECLSPEC_FUC std::vector<ptrdiff_t> stride(void) const;
MGE_WIN_DECLSPEC_FUC float scale(void) const;
MGE_WIN_DECLSPEC_FUC uint8_t zero_point(void) const;


void* data(void);
const void* data(void) const;
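The Device, Shape, DType, Format and Tensor wrappers above are plain value types around the internal representations, so a kernel only ever sees them by value. A small illustrative sketch using the constructors and queries declared in this file; the particular device and dtype strings are examples, and the authoritative lists come from legal_devices() and legal_dtypes():

    #include "megbrain/custom/tensor.h"

    void tensor_meta_demo() {
        custom::Device dev("x86");                      // Device(const char*)
        custom::Shape shp({1, 3, 224, 224});            // Shape(std::initializer_list<size_t>)
        custom::DType dt("float32");                    // DType(const char*)
        custom::DType qt("qint8", /*scale=*/0.5f);      // quantized dtype with scale
        bool ok = custom::DType::is_legal("float16");   // query the supported dtype names
        (void)dev; (void)shp; (void)dt; (void)qt; (void)ok;
    }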


+ 3
- 3
src/custom/include/megbrain/custom/utils.h

@@ -54,9 +54,9 @@ void impl_deleter(void* ptr) {
std::unique_ptr<void, void_deleter> m_impl; \
\
public: \
Cls(); \
Cls(const Cls& rhs); \
Cls& operator=(const Cls& rhs)
MGE_WIN_DECLSPEC_FUC Cls(); \
MGE_WIN_DECLSPEC_FUC Cls(const Cls& rhs); \
MGE_WIN_DECLSPEC_FUC Cls& operator=(const Cls& rhs)


#define CUSTOM_PIMPL_CLS_DEFINE(Cls) \
Cls::Cls() : m_impl(new Cls##Impl(), impl_deleter<Cls##Impl>) {} \


+ 9
- 9
src/gopt/include/megbrain/gopt/framework.h

@@ -375,7 +375,7 @@ public:
~GraphOptimizer() noexcept;

//! add an optimization pass
GraphOptimizer& add_pass(std::unique_ptr<Pass> pass);
MGE_WIN_DECLSPEC_FUC GraphOptimizer& add_pass(std::unique_ptr<Pass> pass);


//! add a pass with given type
template <class Pass, typename... Params>
@@ -415,14 +415,14 @@ public:
const ComputingGraph::Options* comp_graph_opt = nullptr);

//! transform given graph into a new optimized graph
SubGraph apply(const SubGraph& graph) const;
MGE_WIN_DECLSPEC_FUC SubGraph apply(const SubGraph& graph) const;

/*!
* \brief optimize graph defined by given endpoints and modify them
* inplace
* \return *this
*/
const GraphOptimizer& apply_inplace(VarNodeArray& vars) const;
MGE_WIN_DECLSPEC_FUC const GraphOptimizer& apply_inplace(VarNodeArray& vars) const;


/*!
* \brief get var replace map associated with a computing graph
@@ -431,14 +431,14 @@ public:
* Note that the map would be cleared when GraphOptimizer is applied
* on the graph.
*/
static const ThinHashMap<VarNode*, VarNode*>& var_replace_map(
MGE_WIN_DECLSPEC_FUC static const ThinHashMap<VarNode*, VarNode*>& var_replace_map(
ComputingGraph& graph);

/*!
* \brief get the final replaced var in
* var_replace_map(var->owner_graph()) corresponding to var
*/
static VarNode* var_replace_lookup(VarNode* var);
MGE_WIN_DECLSPEC_FUC static VarNode* var_replace_lookup(VarNode* var);


/**
* \brief add pass indicated by optimize options.
@@ -446,10 +446,10 @@ public:
* \param options common options
* \param reset if set true, it will reset options when add passes.
*/
const GraphOptimizer& add_passes_for_optimize_options(
MGE_WIN_DECLSPEC_FUC const GraphOptimizer& add_passes_for_optimize_options(
cg::GraphCommonOptimizeOptions& options, bool reset = false);


const GraphOptimizer& add_passes_for_optimize_options(
MGE_WIN_DECLSPEC_FUC const GraphOptimizer& add_passes_for_optimize_options(
const cg::GraphCommonOptimizeOptions& options);

/**
@@ -457,7 +457,7 @@ public:
*
* \param options graph tuning options
*/
const GraphOptimizer& add_passes_for_graph_tuning_options(
MGE_WIN_DECLSPEC_FUC const GraphOptimizer& add_passes_for_graph_tuning_options(
const GraphTuningOptions& options);
};


@@ -491,7 +491,7 @@ public:
bool all_const_inp;
};


AddOprResult add_opr(OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC AddOprResult add_opr(OperatorNodeBase* opr);


const AddOprResult& opr_rst(OperatorNodeBase* opr) const {
return m_oprinfo.at(opr).result;
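The exported pieces of GraphOptimizer above cover the usual driver loop: construct an optimizer, add passes (either by type through the add_pass<Pass>() template or as ready-made unique_ptr instances), then apply it to the endpoint vars. A hedged sketch under those assumptions; the concrete pass type is an assumption and not part of this change:

    #include "megbrain/gopt/framework.h"
    #include "megbrain/gopt/inference.h"  // assumed location of the pass used below

    void optimize_endpoints(mgb::VarNodeArray& endpoint_vars) {
        mgb::gopt::GraphOptimizer optimizer;
        optimizer.add_pass<mgb::gopt::ParamFusePass>();  // assumed pass type, added by type
        optimizer.apply_inplace(endpoint_vars);          // modify the endpoints in place
    }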


+ 8
- 6
src/gopt/include/megbrain/gopt/inference.h

@@ -384,7 +384,7 @@ struct GraphTuningOptions {
* This function applies a set of predefined optimizer passes to optimize
* for inference. It assumes all params are constant.
*/
SymbolVarArray optimize_for_inference(
MGE_WIN_DECLSPEC_FUC SymbolVarArray optimize_for_inference(
const SymbolVarArray& dest_vars, const OptimizeForInferenceOptions& opt = {});


/*!
@@ -393,7 +393,7 @@ SymbolVarArray optimize_for_inference(
* The layout selection optimizers are target-dependent. And this function
* applies a set of predefined optimizer passes designed for specific
* device. */
SymbolVarArray layout_transform(
MGE_WIN_DECLSPEC_FUC SymbolVarArray layout_transform(
const SymbolVarArray& dest_vars,
GraphTuningOptions::Target target = GraphTuningOptions::Target::UNSPEC);


@@ -404,7 +404,7 @@ SymbolVarArray layout_transform(
* This would modify the operators inplace. It can be used for implement
* the fast-run mode.
*/
void modify_opr_algo_strategy_inplace(
MGE_WIN_DECLSPEC_FUC void modify_opr_algo_strategy_inplace(
const VarNodeArrayView& dest_vars,
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy);


@@ -418,7 +418,8 @@ void modify_opr_algo_strategy_inplace(
* You may want to implement TimedFuncInvoker::ForkExecImpl and/or
* PersistentCache for better performance in an SDK.
*/
void enable_opr_algo_profiling_inplace(const VarNodeArrayView& dest_vars);
MGE_WIN_DECLSPEC_FUC void enable_opr_algo_profiling_inplace(
const VarNodeArrayView& dest_vars);


/*!
* \brief enable opr try profiling cache first, if failed, fallback to
@@ -430,7 +431,8 @@ void enable_opr_algo_profiling_inplace(const VarNodeArrayView& dest_vars);
* You may want to implement TimedFuncInvoker::ForkExecImpl and/or
* PersistentCache for better performance in an SDK.
*/
void enable_opr_use_profiling_cache_inplace(const VarNodeArrayView& dest_vars);
MGE_WIN_DECLSPEC_FUC void enable_opr_use_profiling_cache_inplace(
const VarNodeArrayView& dest_vars);


/*!
* \brief set workspace_limit for execution strategy for oprs with multiple
@@ -442,7 +444,7 @@ void enable_opr_use_profiling_cache_inplace(const VarNodeArrayView& dest_vars);
* \warning It will influence the default algo choosed, and maybe slower but
* save memory.
*/
void set_opr_algo_workspace_limit_inplace(
MGE_WIN_DECLSPEC_FUC void set_opr_algo_workspace_limit_inplace(
const VarNodeArrayView& dest_vars, size_t workspace_limit);


/*!
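These free functions are whole-graph entry points; exporting them lets a Windows application linked against megengine.dll run the same inference-optimization flow as on other platforms. A minimal sketch, assuming dest_vars already holds the SymbolVarArray of network outputs obtained from a loaded model:

    #include "megbrain/gopt/inference.h"

    mgb::SymbolVarArray optimize_outputs(const mgb::SymbolVarArray& dest_vars) {
        mgb::gopt::OptimizeForInferenceOptions opt;  // default options
        return mgb::gopt::optimize_for_inference(dest_vars, opt);
    }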


+ 39
- 19
src/megbrain_build_config.h.in

@@ -60,37 +60,37 @@


// whether atlas is available
#ifndef MGB_ATLAS
#define MGB_ATLAS 0
#endif

// whether cuda is available
#ifndef MGB_CUDA
#define MGB_CUDA 1
#endif

// whether to include file/line location for assert message
#ifndef MGB_ASSERT_LOC
#define MGB_ASSERT_LOC 1
#endif

// whether to enable utils/debug.h and other debug methods
#ifndef MGB_ENABLE_DEBUG_UTIL
#define MGB_ENABLE_DEBUG_UTIL 1
#endif

// whether to enable logging
#ifndef MGB_ENABLE_LOGGING
#define MGB_ENABLE_LOGGING 1
#endif

// whether to enable registering opr grad functions
#ifndef MGB_ENABLE_GRAD
#define MGB_ENABLE_GRAD 1
#endif

// whether to enable cpuinfo
#ifndef MGB_ENABLE_CPUINFO
#define MGB_ENABLE_CPUINFO 1
#endif


//! use one MACRO indicate enable_arm_dotprod
@@ -101,7 +101,6 @@
#define MGB_ENABLE_DOT 1
#endif

//! ENABLE MGB DOT should enable CPUINFO
#if MGB_ENABLE_DOT
#if !defined(MGB_ENABLE_CPUINFO) || !MGB_ENABLE_CPUINFO
@@ -115,38 +114,38 @@
//! IOS disabled cpuinfo and dotprod, cpuinfo has some problem on ios
#ifdef IOS
#undef MGB_ENABLE_CPUINFO
#define MGB_ENABLE_CPUINFO 0
#undef MGB_ENABLE_DOT
#endif

// whether to include actual class name in mgb::Typeinfo object; if this is
// disabled, mgb::serialization::OprRegistry::find_opr_by_name would not work.
#ifndef MGB_VERBOSE_TYPEINFO_NAME
#define MGB_VERBOSE_TYPEINFO_NAME 1
#endif

// whether to enbale configuing megbrain internals through env vars
#ifndef MGB_ENABLE_GETENV
#define MGB_ENABLE_GETENV MGB_ASSERT_LOC
#endif

// whether to remove unnecessary features when used for serving
#ifndef MGB_BUILD_SLIM_SERVING
#define MGB_BUILD_SLIM_SERVING 0
#endif

// whether to enable exception
#ifndef MGB_ENABLE_EXCEPTION
#if __EXCEPTIONS
#define MGB_ENABLE_EXCEPTION 1
#else
#define MGB_ENABLE_EXCEPTION 0
#endif
#endif

// whether <thread> is available and usable
#ifndef MGB_HAVE_THREAD
#define MGB_HAVE_THREAD 1
#endif

// whether to trade thread safety for memory usage
@@ -156,7 +155,7 @@


// whether to enable JIT
#ifndef MGB_JIT
#define MGB_JIT 1
#endif
#ifndef MGB_JIT_HALIDE
#define MGB_JIT_HALIDE 0
@@ -174,10 +173,9 @@
#define MGB_CAMBRICON MEGDNN_WITH_CAMBRICON
#endif

// whether to enable TensorRT support
#ifndef MGB_ENABLE_TENSOR_RT
#define MGB_ENABLE_TENSOR_RT MGB_CUDA
#endif

// whether to enable fastrun profile
@@ -252,4 +250,26 @@
#define MEGDNN_X86_WITH_MKL_DNN 0
#endif


#endif // _HEADER_MGB_BUILD_CONFIG
#ifdef WIN32
#ifdef MGE_DLL_EXPORT
#define MGE_WIN_DECLSPEC_FUC __declspec(dllexport)
#else
#define MGE_WIN_DECLSPEC_FUC
#endif
#else
#define MGE_WIN_DECLSPEC_FUC
#endif

#ifdef WIN32
#if defined(MGE_DLL_EXPORT_DATA)
#define MGE_WIN_DECLSPEC_DATA __declspec(dllexport)
#elif defined(MGE_DLL_IMPORT_DATA)
#define MGE_WIN_DECLSPEC_DATA __declspec(dllimport)
#else
#define MGE_WIN_DECLSPEC_DATA
#endif
#else
#define MGE_WIN_DECLSPEC_DATA
#endif

#endif // _HEADER_MGB_BUILD_CONFIG
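The two macros added above drive the Windows DLL export of the headers in this change: annotated symbols become __declspec(dllexport) only while MGE_DLL_EXPORT is defined, i.e. when the DLL itself is being compiled (how the build system defines it is outside this hunk), and on other platforms or in static builds both macros expand to nothing. A hedged illustration of how a header consumes them; Foo is an invented class, not from the sources:

    class Foo {
    public:
        MGE_WIN_DECLSPEC_FUC static Foo* inst();   // function and member export
        MGE_WIN_DECLSPEC_DATA static int g_flag;   // data symbols use the _DATA variant,
                                                   // which also covers dllimport on the consumer side
    };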

+ 20
- 17
src/opr/include/megbrain/opr/basic_arith.h

@@ -58,22 +58,22 @@ public:
* The operands are broadcasted automatically on dimensions of shape one to
* match shapes of each other; it works like broadcasting in numpy.
*/
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Elemwise, intl::ElemwiseBase, mixin::FwdIn2OutWritableHelper) // {
using ModeTrait = megdnn::Elemwise::ModeTrait;

public:
using Mode = Param::Mode;


Elemwise(
MGE_WIN_DECLSPEC_FUC Elemwise(
const ModeTrait& mode_trait, const VarNodeArrayView& inputs, Param param,
const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const VarNodeArrayView& inputs, Param param,
const OperatorNodeConfig& config = {});


static TensorShape get_output_var_shape(
MGE_WIN_DECLSPEC_FUC static TensorShape get_output_var_shape(
Mode mode, const TensorShapeArray& input_shapes);


/*!
@@ -84,7 +84,7 @@ public:
* \param opr the megdnn operator to be used; a new operator would be
* created if it is null
*/
static void perform(
MGE_WIN_DECLSPEC_FUC static void perform(
Mode mode, DeviceTensorND& dest, const SmallVector<DeviceTensorND>& inputs,
intl::UniqPtrWithCN<megdnn::Elemwise>& opr);


@@ -98,10 +98,12 @@ public:
* \param layouts the layouts to be collectively collapsed
*
*/
static TensorLayoutArray collective_collapse(const TensorLayoutArray& layouts);
MGE_WIN_DECLSPEC_FUC static TensorLayoutArray collective_collapse(
const TensorLayoutArray& layouts);

//! like collective_collapse(), but modify the layouts inplace
static void collective_collapse_inplace(const TensorLayoutPtrArray& layouts);
MGE_WIN_DECLSPEC_FUC static void collective_collapse_inplace(
const TensorLayoutPtrArray& layouts);


/*! /*!
* \brief wapper for broadcast and collective collapse * \brief wapper for broadcast and collective collapse
@@ -111,7 +113,7 @@ public:
* \param[in,out] target_layout broadcasted target layout; it would be * \param[in,out] target_layout broadcasted target layout; it would be
* collapsed together with inputs * collapsed together with inputs
*/ */
static void broadcast_collective_collapse(
MGE_WIN_DECLSPEC_FUC static void broadcast_collective_collapse(
const TensorLayoutPtrArray& inp_layouts, TensorLayout* target_layout); const TensorLayoutPtrArray& inp_layouts, TensorLayout* target_layout);


/*! /*!
@@ -128,7 +130,8 @@ public:
* \param[in,out] grads vars to be summed; it is also an output param, * \param[in,out] grads vars to be summed; it is also an output param,
* which would contain all the intermediate results for summing * which would contain all the intermediate results for summing
*/ */
static VarNode* sum_grad_list(VarNode* wrt, VarNodeArray& grads);
MGE_WIN_DECLSPEC_FUC static VarNode* sum_grad_list(
VarNode* wrt, VarNodeArray& grads);


//! whether input layouts mismatch ever happened for fused oprs; this //! whether input layouts mismatch ever happened for fused oprs; this
//! method is public for debug purpose //! method is public for debug purpose
@@ -163,11 +166,11 @@ using TypeCvtBase = cg::OutshapePureByInshapeOpr<
cg::mixin::IOSameShapeOperatorNode>; cg::mixin::IOSameShapeOperatorNode>;
} }


MGB_DEFINE_OPR_CLASS(TypeCvt, intl::TypeCvtBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(TypeCvt, intl::TypeCvtBase) // {
public: public:
TypeCvt(VarNode* inp, DType dest_type, const OperatorNodeConfig& config); TypeCvt(VarNode* inp, DType dest_type, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar input, DType dest_type, const OperatorNodeConfig& config = {}); SymbolVar input, DType dest_type, const OperatorNodeConfig& config = {});


static void perform( static void perform(
@@ -200,7 +203,7 @@ private:
* Attention: AddUpdate will not be executed if disable flag is set to 1, * Attention: AddUpdate will not be executed if disable flag is set to 1,
* this is used for dynamic param-updating. * this is used for dynamic param-updating.
*/ */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
AddUpdate, cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolder>) // { AddUpdate, cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolder>) // {
public: public:
using SharedScalar = std::shared_ptr<DTypeScalar>; using SharedScalar = std::shared_ptr<DTypeScalar>;
@@ -235,7 +238,7 @@ public:
VarNode* dest, VarNode* delta, const Param& param, VarNode* dest, VarNode* delta, const Param& param,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar dest, SymbolVar delta, const Param& param = {}, SymbolVar dest, SymbolVar delta, const Param& param = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -256,7 +259,7 @@ private:
* Mode specifies the actual arithmetic; and exactly one of *axis* and * Mode specifies the actual arithmetic; and exactly one of *axis* and
* *target_shape* must be provided, to specify output shape. * *target_shape* must be provided, to specify output shape.
*/ */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Reduce, intl::DynamicOutputIfInputDynamic< Reduce, intl::DynamicOutputIfInputDynamic<
intl::OutshapeBySymvarSCNOpr<mixin::MegDNNOprHolder>>) // { intl::OutshapeBySymvarSCNOpr<mixin::MegDNNOprHolder>>) // {
public: public:
@@ -269,7 +272,7 @@ public:


const Param& param() const { return m_param; } const Param& param() const { return m_param; }


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, Param param, SymbolVar target_shape = {}, SymbolVar src, Param param, SymbolVar target_shape = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -317,10 +320,10 @@ private:
* graph with only Elemwise::Mode::POW, and this opr should only be inserted by * graph with only Elemwise::Mode::POW, and this opr should only be inserted by
* the optimizer. * the optimizer.
*/ */
MGB_DEFINE_OPR_CLASS(PowC, intl::MegDNNOprWrapperFwd<megdnn::PowC>) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(PowC, intl::MegDNNOprWrapperFwd<megdnn::PowC>) // {
public: public:
PowC(VarNode* inp, const Param& param, const OperatorNodeConfig& config); PowC(VarNode* inp, const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar inp, const Param& param = {},
const OperatorNodeConfig& config = {});
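Elemwise::make, TypeCvt::make and the other make() builders exported here are the usual way graph code instantiates these operators. A hedged sketch, assuming x and y are SymbolVars already attached to a computing graph created elsewhere:

    #include "megbrain/opr/basic_arith.h"

    mgb::SymbolVar add_and_cast(mgb::SymbolVar x, mgb::SymbolVar y) {
        auto sum = mgb::opr::Elemwise::make({x, y}, {mgb::opr::Elemwise::Mode::ADD});
        return mgb::opr::TypeCvt::make(sum, mgb::dtype::Float16());
    }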




+ 10
- 8
src/opr/include/megbrain/opr/blas.h

@@ -31,11 +31,11 @@ MGB_DEFINE_OPR_CLASS(
public mixin::AlgoChooserHelper) // {
public:
using AlgorithmInfo = megdnn::detail::Algorithm::Info;
MatrixMul(
MGE_WIN_DECLSPEC_FUC MatrixMul(
VarNode* opr0, VarNode* opr1, const Param& param, VarNode* opr0, VarNode* opr1, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config); const ExecutionPolicy& policy, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar opr0, SymbolVar opr1, const Param& param = {}, SymbolVar opr0, SymbolVar opr1, const Param& param = {},
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});


@@ -62,11 +62,11 @@ MGB_DEFINE_OPR_CLASS(
public mixin::AlgoChooserHelper) // {
public:
using AlgorithmInfo = megdnn::detail::Algorithm::Info;
BatchedMatrixMul(
MGE_WIN_DECLSPEC_FUC BatchedMatrixMul(
VarNode* opr0, VarNode* opr1, const Param& param, VarNode* opr0, VarNode* opr1, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config); const ExecutionPolicy& policy, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar opr0, SymbolVar opr1, const Param& param = {}, SymbolVar opr0, SymbolVar opr1, const Param& param = {},
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});


@@ -91,9 +91,10 @@ private:
MGB_DEFINE_OPR_CLASS(
Dot, cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolderImpl<megdnn::Dot>>) // {
public:
Dot(VarNode* opr0, VarNode* opr1, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC Dot(
VarNode* opr0, VarNode* opr1, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar opr0, SymbolVar opr1, const OperatorNodeConfig& config = {}); SymbolVar opr0, SymbolVar opr1, const OperatorNodeConfig& config = {});


// for serialization // for serialization
@@ -115,8 +116,9 @@ MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(MatrixInverse);


MGB_DEFINE_OPR_CLASS(SVD, intl::MegDNNOprWrapperFwd<megdnn::SVD>) // { MGB_DEFINE_OPR_CLASS(SVD, intl::MegDNNOprWrapperFwd<megdnn::SVD>) // {
public: public:
SVD(VarNode * src, const Param& param, const OperatorNodeConfig& config);
static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC SVD(
VarNode * src, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
const SymbolVar& src, const Param& param = {}, const SymbolVar& src, const Param& param = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
}; };


+ 13
- 13
src/opr/include/megbrain/opr/custom_opnode.h

@@ -84,26 +84,26 @@ MGB_DEFINE_OPR_CLASS(CustomOpNode, cg::OperatorNodeBase) // {
bool update_priority() const override final;

public:
CustomOpNode(
MGE_WIN_DECLSPEC_FUC CustomOpNode(
const std::shared_ptr<const custom::CustomOp>& op, VarNodeArray inputs, const std::shared_ptr<const custom::CustomOp>& op, VarNodeArray inputs,
const custom::Param& param, const OperatorNodeConfig& config); const custom::Param& param, const OperatorNodeConfig& config);
static VarNodeArray make(
MGE_WIN_DECLSPEC_FUC static VarNodeArray make(
const std::shared_ptr<const custom::CustomOp>& op, VarNodeArray inputs, const std::shared_ptr<const custom::CustomOp>& op, VarNodeArray inputs,
const custom::Param& param, const OperatorNodeConfig& config); const custom::Param& param, const OperatorNodeConfig& config);
static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
const std::shared_ptr<const custom::CustomOp>& op, SymbolVarArray inputs, const std::shared_ptr<const custom::CustomOp>& op, SymbolVarArray inputs,
const custom::Param& param, const OperatorNodeConfig& config); const custom::Param& param, const OperatorNodeConfig& config);


custom::RunTimeId runtime_id(void) const;
uint32_t param_tag(void) const;
custom::Param& param(void);
custom::Param param(void) const;
std::string op_type(void) const;
std::string op_desc(void) const;
size_t input_num(void) const;
size_t output_num(void) const;
custom::ArgInfo input_info(size_t idx) const;
custom::ArgInfo output_info(size_t idx) const;
MGE_WIN_DECLSPEC_FUC custom::RunTimeId runtime_id(void) const;
MGE_WIN_DECLSPEC_FUC uint32_t param_tag(void) const;
MGE_WIN_DECLSPEC_FUC custom::Param& param(void);
MGE_WIN_DECLSPEC_FUC custom::Param param(void) const;
MGE_WIN_DECLSPEC_FUC std::string op_type(void) const;
MGE_WIN_DECLSPEC_FUC std::string op_desc(void) const;
MGE_WIN_DECLSPEC_FUC size_t input_num(void) const;
MGE_WIN_DECLSPEC_FUC size_t output_num(void) const;
MGE_WIN_DECLSPEC_FUC custom::ArgInfo input_info(size_t idx) const;
MGE_WIN_DECLSPEC_FUC custom::ArgInfo output_info(size_t idx) const;
};

} // namespace opr


+ 4
- 4
src/opr/include/megbrain/opr/dnn/adaptive_pooling.h

@@ -25,10 +25,10 @@ MGB_DEFINE_OPR_CLASS(
intl::WorkspaceSizeInfer<intl::OutshapeBySymvarSCNOpr<
mixin::MegDNNOprHolderImpl<megdnn::AdaptivePoolingForward>>>) // {
public:
AdaptivePoolingForward(
MGE_WIN_DECLSPEC_FUC AdaptivePoolingForward(
VarNode* src, VarNode* out_shape, const Param& param, VarNode* src, VarNode* out_shape, const Param& param,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar out_shape, const Param& param, SymbolVar src, SymbolVar out_shape, const Param& param,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
static SymbolVar make( static SymbolVar make(
@@ -55,10 +55,10 @@ MGB_DEFINE_OPR_CLASS(
AdaptivePoolingBackward,
intl::MegDNNOprWrapperBwd<megdnn::AdaptivePoolingBackward>) // {
public:
AdaptivePoolingBackward(
MGE_WIN_DECLSPEC_FUC AdaptivePoolingBackward(
VarNode* src, VarNode* out_shape, VarNode* dst, VarNode* diff, VarNode* src, VarNode* out_shape, VarNode* dst, VarNode* diff,
const Param& param, const OperatorNodeConfig& config); const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar out_shape, SymbolVar dst, SymbolVar diff, SymbolVar src, SymbolVar out_shape, SymbolVar dst, SymbolVar diff,
const Param& param, const OperatorNodeConfig& config = {}); const Param& param, const OperatorNodeConfig& config = {});




+ 8
- 8
src/opr/include/megbrain/opr/dnn/batch_norm.h

@@ -39,26 +39,26 @@ namespace opr {
* Output reserve is used for cudnnBatchNormalizationForwardTrainingEx, and should
* be preserved for backward.
*/
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
BatchNormForward,
cg::OutshapePureByInshapeOpr<
intl::WorkspaceSizeInfer<cg::SingleCNOperatorNodeBaseT<
mixin::MegDNNOprHolderImpl<megdnn::BN>>>>) // {
public:
BatchNormForward(
MGE_WIN_DECLSPEC_FUC BatchNormForward(
VarNode* x, VarNode* scale, VarNode* bias, VarNode* mean, VarNode* variance, VarNode* x, VarNode* scale, VarNode* bias, VarNode* mean, VarNode* variance,
const Param& param, const OperatorNodeConfig& config); const Param& param, const OperatorNodeConfig& config);


BatchNormForward(
MGE_WIN_DECLSPEC_FUC BatchNormForward(
VarNode* x, VarNode* scale, VarNode* bias, const Param& param, VarNode* x, VarNode* scale, VarNode* bias, const Param& param,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar x, SymbolVar scale, SymbolVar bias, SymbolVar mean, SymbolVar x, SymbolVar scale, SymbolVar bias, SymbolVar mean,
SymbolVar variance, const Param& param = {}, SymbolVar variance, const Param& param = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar x, SymbolVar scale, SymbolVar bias, const Param& param = {}, SymbolVar x, SymbolVar scale, SymbolVar bias, const Param& param = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -93,14 +93,14 @@ using BatchNorm = BatchNormForward;
* scale_grad, bias_grad, x_grad * scale_grad, bias_grad, x_grad
*/ */


MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
BatchNormBackward, intl::MegDNNOprWrapperBwd<megdnn::BNBackward>) // {
public:
BatchNormBackward(
MGE_WIN_DECLSPEC_FUC BatchNormBackward(
VarNode* x, VarNode* y_grad, VarNode* save_mean, VarNode* save_variance, VarNode* x, VarNode* y_grad, VarNode* save_mean, VarNode* save_variance,
VarNode* scale, VarNode* reserve, const Param& param, VarNode* scale, VarNode* reserve, const Param& param,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);
static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar x, SymbolVar y_grad, SymbolVar save_mean, SymbolVar save_variance, SymbolVar x, SymbolVar y_grad, SymbolVar save_mean, SymbolVar save_variance,
SymbolVar scale, SymbolVar reserve, const Param& param = {}, SymbolVar scale, SymbolVar reserve, const Param& param = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


+ 63
- 62
src/opr/include/megbrain/opr/dnn/convolution.h

@@ -93,7 +93,7 @@ class ConvolutionTestingPeer;


} // namespace testing


MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ConvolutionForward, intl::ConvolutionForwardBase,
public mixin::AlgoChooserHelper) // {
void init_output_dtype() override;
@@ -114,17 +114,17 @@ MGB_DEFINE_OPR_CLASS(
friend testing::ConvolutionTestingPeer; friend testing::ConvolutionTestingPeer;


public: public:
ConvolutionForward(
MGE_WIN_DECLSPEC_FUC ConvolutionForward(
VarNode* src, VarNode* filter, const Param& param, VarNode* src, VarNode* filter, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config); const ExecutionPolicy& policy, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar filter, const Param& param = {},
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
};
using Convolution = ConvolutionForward;


MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ConvBiasForward, intl::ConvBiasForwardBase, public mixin::AlgoChooserHelper) // {
void init_output_dtype() override;
size_t get_workspace_size_bytes(
@@ -147,37 +147,37 @@ MGB_DEFINE_OPR_CLASS(


public: public:
//! src * filter //! src * filter
ConvBiasForward(
MGE_WIN_DECLSPEC_FUC ConvBiasForward(
VarNode* src, VarNode* filter, const Param& param, VarNode* src, VarNode* filter, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config); const ExecutionPolicy& policy, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar filter, const Param& param = {}, SymbolVar src, SymbolVar filter, const Param& param = {},
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});


//! src * filter + bias //! src * filter + bias
ConvBiasForward(
MGE_WIN_DECLSPEC_FUC ConvBiasForward(
VarNode* src, VarNode* filter, VarNode* bias, const Param& param, VarNode* src, VarNode* filter, VarNode* bias, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config); const ExecutionPolicy& policy, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar filter, SymbolVar bias, const Param& param = {}, SymbolVar src, SymbolVar filter, SymbolVar bias, const Param& param = {},
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});


//! src * filter + bias + z //! src * filter + bias + z
ConvBiasForward(
MGE_WIN_DECLSPEC_FUC ConvBiasForward(
VarNode* src, VarNode* filter, VarNode* bias, VarNode* z, VarNode* src, VarNode* filter, VarNode* bias, VarNode* z,
const Param& param, const ExecutionPolicy& policy, const Param& param, const ExecutionPolicy& policy,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar filter, SymbolVar bias, SymbolVar z, SymbolVar src, SymbolVar filter, SymbolVar bias, SymbolVar z,
const Param& param = {}, const ExecutionPolicy& policy = {}, const Param& param = {}, const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


static void check_winograd_param_valid(
MGE_WIN_DECLSPEC_FUC static void check_winograd_param_valid(
const megdnn::ConvBias::WinogradParam& param, const DType& dtype); const megdnn::ConvBias::WinogradParam& param, const DType& dtype);
static megdnn::param::MatrixMul::Format get_matmul_format(
MGE_WIN_DECLSPEC_FUC static megdnn::param::MatrixMul::Format get_matmul_format(
const megdnn::ConvBias::WinogradParam& param); const megdnn::ConvBias::WinogradParam& param);
}; };
using ConvBias = ConvBiasForward; using ConvBias = ConvBiasForward;
@@ -185,7 +185,7 @@ using ConvBias = ConvBiasForward;
/*! /*!
* \brief Can be used in two ways: compute gradient of conv, or deconv * \brief Can be used in two ways: compute gradient of conv, or deconv
*/ */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ConvolutionBackwardData, ConvolutionBackwardData,
cg::SingleCNOperatorNodeBaseT< cg::SingleCNOperatorNodeBaseT<
mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>, mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>,
@@ -200,30 +200,30 @@ MGB_DEFINE_OPR_CLASS(
NodeProp* do_make_node_prop() const override; NodeProp* do_make_node_prop() const override;


public: public:
ConvolutionBackwardData(
MGE_WIN_DECLSPEC_FUC ConvolutionBackwardData(
VarNode* filter, VarNode* diff, VarNode* src_for_shp, const Param& param, VarNode* filter, VarNode* diff, VarNode* src_for_shp, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config); const ExecutionPolicy& policy, const OperatorNodeConfig& config);


//! grad mode; original data shape is required //! grad mode; original data shape is required
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar filter, SymbolVar diff, SymbolVar src_for_shp, SymbolVar filter, SymbolVar diff, SymbolVar src_for_shp,
const Param& param = {}, const ExecutionPolicy& policy = {}, const Param& param = {}, const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


//! sereg for deconvolution mode //! sereg for deconvolution mode
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar filter, SymbolVar data, const Param& param, SymbolVar filter, SymbolVar data, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config); const ExecutionPolicy& policy, const OperatorNodeConfig& config);


//! user interface for deconv //! user interface for deconv
static SymbolVar make_deconv(
MGE_WIN_DECLSPEC_FUC static SymbolVar make_deconv(
SymbolVar data, SymbolVar filter, const Param& param = {}, SymbolVar data, SymbolVar filter, const Param& param = {},
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}) { const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}) {
return make(filter, data, param, policy, config); return make(filter, data, param, policy, config);
} }
}; };


MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ConvolutionBackwardFilter, ConvolutionBackwardFilter,
intl::MegDNNOprWrapperBwd<megdnn::ConvolutionBackwardFilter>, intl::MegDNNOprWrapperBwd<megdnn::ConvolutionBackwardFilter>,
public mixin::AlgoChooserHelper) // { public mixin::AlgoChooserHelper) // {
@@ -232,40 +232,41 @@ MGB_DEFINE_OPR_CLASS(
const TensorShapeArray& output_shapes) const override final; const TensorShapeArray& output_shapes) const override final;


public: public:
ConvolutionBackwardFilter(
MGE_WIN_DECLSPEC_FUC ConvolutionBackwardFilter(
VarNode* src, VarNode* diff, VarNode* filter, const Param& param, VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar diff, SymbolVar filter, const Param& param,
             const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         MaskConvolution, intl::MegDNNOprWrapperFwd<megdnn::MaskConvolution>) // {
     void init_output_dtype() override final;

 public:
-    MaskConvolution(
+    MGE_WIN_DECLSPEC_FUC MaskConvolution(
             VarNode* src, VarNode* filter, VarNode* mask, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, SymbolVar mask, const Param& param,
             const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         MaskPropagate, intl::MegDNNOprWrapperFwd<megdnn::MaskPropagate>) // {
     void init_output_dtype() override final;

 public:
-    MaskPropagate(VarNode* src, const Param& param, const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC MaskPropagate(
+            VarNode* src, const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         Convolution3DForward, intl::MegDNNOprWrapperFwd<megdnn::Convolution3DForward>,
         public mixin::AlgoChooserHelper) // {
     void init_output_dtype() override;
@@ -274,11 +275,11 @@ MGB_DEFINE_OPR_CLASS(
             const TensorShapeArray& output_shapes) const override final;

 public:
-    Convolution3DForward(
+    MGE_WIN_DECLSPEC_FUC Convolution3DForward(
             VarNode* src, VarNode* filter, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, const Param& param = {},
             const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
 };
@@ -287,7 +288,7 @@ using Convolution3D = Convolution3DForward;
 /*!
  * \brief Can be used in two ways: compute gradient of conv, or deconv
  */
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         Convolution3DBackwardData,
         cg::SingleCNOperatorNodeBaseT<
                 mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>,
@@ -300,18 +301,18 @@ MGB_DEFINE_OPR_CLASS(
     NodeProp* do_make_node_prop() const override;

 public:
-    Convolution3DBackwardData(
+    MGE_WIN_DECLSPEC_FUC Convolution3DBackwardData(
             VarNode* filter, VarNode* diff, VarNode* src_for_shp, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);

     //! grad mode; original data shape is required
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar filter, SymbolVar diff, SymbolVar src_for_shp,
             const Param& param = {}, const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});

     //! sereg for deconvolution3D mode
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar filter, SymbolVar data, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);

@@ -323,7 +324,7 @@ public:
     }
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         Convolution3DBackwardFilter,
         intl::MegDNNOprWrapperBwd<megdnn::Convolution3DBackwardFilter>,
         public mixin::AlgoChooserHelper) // {
@@ -332,15 +333,15 @@ MGB_DEFINE_OPR_CLASS(
             const TensorShapeArray& output_shapes) const override final;

 public:
-    Convolution3DBackwardFilter(
+    MGE_WIN_DECLSPEC_FUC Convolution3DBackwardFilter(
             VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar diff, SymbolVar filter, const Param& param,
             const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         LocalShareForward, intl::MegDNNOprWrapperFwd<megdnn::LocalShareForward>,
         public mixin::AlgoChooserHelper) // {
     void init_output_dtype() override;
@@ -351,16 +352,16 @@ MGB_DEFINE_OPR_CLASS(
             const TensorShapeArray& output_shapes) const override final;

 public:
-    LocalShareForward(
+    MGE_WIN_DECLSPEC_FUC LocalShareForward(
             VarNode* src, VarNode* filter, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, const Param& param = {},
             const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
 };
 using LocalShare = LocalShareForward;

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         LocalShareBackwardData,
         cg::SingleCNOperatorNodeBaseT<
                 mixin::MegDNNOprHolderImpl<megdnn::LocalShareBackwardData>>,
@@ -374,18 +375,18 @@ MGB_DEFINE_OPR_CLASS(
     NodeProp* do_make_node_prop() const override;

 public:
-    LocalShareBackwardData(
+    MGE_WIN_DECLSPEC_FUC LocalShareBackwardData(
             VarNode* filter, VarNode* diff, VarNode* src_for_shp, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);

     //! grad mode; original data shape is required
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar filter, SymbolVar diff, SymbolVar src_for_shp,
             const Param& param = {}, const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         LocalShareBackwardFilter,
         intl::MegDNNOprWrapperBwd<megdnn::LocalShareBackwardFilter>,
         public mixin::AlgoChooserHelper) // {
@@ -394,24 +395,24 @@ MGB_DEFINE_OPR_CLASS(
             const TensorShapeArray& output_shapes) const override final;

 public:
-    LocalShareBackwardFilter(
+    MGE_WIN_DECLSPEC_FUC LocalShareBackwardFilter(
             VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar diff, SymbolVar filter, const Param& param,
             const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         DeformableConvForward, intl::MegDNNOprWrapperFwd<megdnn::DeformableConvForward>,
         public mixin::AlgoChooserHelper) // {
 public:
-    DeformableConvForward(
+    MGE_WIN_DECLSPEC_FUC DeformableConvForward(
             VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
             const Param& param, const ExecutionPolicy& policy,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, SymbolVar offset, SymbolVar mask,
             const Param& param = {}, const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});
@@ -425,21 +426,21 @@ private:
 };
 using DeformableConv = DeformableConvForward;

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         DeformableConvBackwardData, intl::DeformableConvBackwardDataBase,
         public mixin::AlgoChooserHelper, public mixin::ConvolutionBackwardDataMixin) // {
 public:
-    DeformableConvBackwardData(
+    MGE_WIN_DECLSPEC_FUC DeformableConvBackwardData(
             VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
             VarNode* diff, const Param& param, const ExecutionPolicy& policy,
             const OperatorNodeConfig& config);

-    static SymbolVarArray make_all(
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make_all(
             SymbolVar src, SymbolVar filter, SymbolVar offset, SymbolVar mask,
             SymbolVar diff, const Param& param = {}, const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, SymbolVar offset, SymbolVar mask,
             SymbolVar diff, const Param& param = {}, const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});
@@ -463,17 +464,17 @@ private:
     }
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         DeformableConvBackwardFilter,
         intl::MegDNNOprWrapperBwd<megdnn::DeformableConvBackwardFilter>,
         public mixin::AlgoChooserHelper) // {
 public:
-    DeformableConvBackwardFilter(
+    MGE_WIN_DECLSPEC_FUC DeformableConvBackwardFilter(
             VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
             VarNode* diff, const Param& param, const ExecutionPolicy& policy,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, SymbolVar offset, SymbolVar mask,
             SymbolVar diff, const Param& param = {}, const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});
@@ -486,7 +487,7 @@ private:
             const TensorShapeArray& output_shapes) const override final;
 };

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         BatchConvBiasForward, intl::BatchConvBiasForwardBase,
         public mixin::AlgoChooserHelper) // {
     void init_output_dtype() override;
@@ -506,30 +507,30 @@ MGB_DEFINE_OPR_CLASS(

 public:
     //! src * filter
-    BatchConvBiasForward(
+    MGE_WIN_DECLSPEC_FUC BatchConvBiasForward(
             VarNode* src, VarNode* filter, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, const Param& param = {},
             const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});

     //! src * filter + bias
-    BatchConvBiasForward(
+    MGE_WIN_DECLSPEC_FUC BatchConvBiasForward(
             VarNode* src, VarNode* filter, VarNode* bias, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, SymbolVar bias, const Param& param = {},
             const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});

     //! src * filter + bias + z
-    BatchConvBiasForward(
+    MGE_WIN_DECLSPEC_FUC BatchConvBiasForward(
             VarNode* src, VarNode* filter, VarNode* bias, VarNode* z,
             const Param& param, const ExecutionPolicy& policy,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar filter, SymbolVar bias, SymbolVar z,
             const Param& param = {}, const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});

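Note on the export tag seen throughout these hunks: MGE_WIN_DECLSPEC_FUC marks individual constructors and make() factories for export from megengine.dll. The sketch below shows the usual dllexport/dllimport pattern such a macro follows; it is an illustration of the mechanism only, not the macro's actual definition in this commit, and the guard names are assumptions.

    // Illustrative sketch only -- NOT the definition used by this commit.
    // Guard names (SKETCH_*, MGE_BUILD_DLL) are assumed for the example.
    #if defined(_WIN32)
    #if defined(MGE_BUILD_DLL)  // building megengine_shared itself
    #define SKETCH_MGE_WIN_DECLSPEC_FUC __declspec(dllexport)
    #else                       // consuming megengine_shared from another binary
    #define SKETCH_MGE_WIN_DECLSPEC_FUC __declspec(dllimport)
    #endif
    #else                       // non-Windows builds: no declspec needed
    #define SKETCH_MGE_WIN_DECLSPEC_FUC
    #endif

With a per-symbol tag like this, only the annotated declarations end up in the DLL's export table instead of every symbol in the library.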
+6 -6   src/opr/include/megbrain/opr/dnn/correlation.h

@@ -20,11 +20,11 @@ namespace opr {
 MGB_DEFINE_OPR_CLASS(
         CorrelationForward, intl::MegDNNOprWrapperFwd<megdnn::CorrelationForward>) // {
 public:
-    CorrelationForward(
+    MGE_WIN_DECLSPEC_FUC CorrelationForward(
             VarNode* data1, VarNode* data2, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar data1, SymbolVar data2, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -34,11 +34,11 @@ MGB_DEFINE_OPR_CLASS(
         CorrelationBackwardData1,
         intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData1>) // {
 public:
-    CorrelationBackwardData1(
+    MGE_WIN_DECLSPEC_FUC CorrelationBackwardData1(
             VarNode* diff, VarNode* data1, VarNode* data2, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar diff, SymbolVar data1, SymbolVar data2, const Param& param = {},
             const OperatorNodeConfig& config = {});

@@ -53,11 +53,11 @@ MGB_DEFINE_OPR_CLASS(
         CorrelationBackwardData2,
         intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData2>) // {
 public:
-    CorrelationBackwardData2(
+    MGE_WIN_DECLSPEC_FUC CorrelationBackwardData2(
             VarNode* diff, VarNode* data1, VarNode* data2, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar diff, SymbolVar data1, SymbolVar data2, const Param& param = {},
             const OperatorNodeConfig& config = {});

+5 -5   src/opr/include/megbrain/opr/dnn/fake_quant.h

@@ -18,11 +18,11 @@ namespace opr {
 MGB_DEFINE_OPR_CLASS(
         FakeQuantForward, intl::MegDNNOprWrapperFwd<megdnn::FakeQuantForward>) // {
 public:
-    FakeQuantForward(
+    MGE_WIN_DECLSPEC_FUC FakeQuantForward(
             VarNode* src, VarNode* scale, VarNode* zero_point, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar scale, SymbolVar zero_point,
             const Param& param = {}, const OperatorNodeConfig& config = {});
 };  // namespace opr
@@ -31,14 +31,14 @@ using FakeQuant = FakeQuantForward;
 MGB_DEFINE_OPR_CLASS(
         FakeQuantBackward, intl::MegDNNOprWrapperBwd<megdnn::FakeQuantBackward>) // {
 public:
-    FakeQuantBackward(
+    MGE_WIN_DECLSPEC_FUC FakeQuantBackward(
             VarNode* diff, VarNode* input, VarNode* scale, VarNode* zero_point,
             const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar diff, SymbolVar input, SymbolVar scale, SymbolVar zero_point,
             const Param& param = {}, const OperatorNodeConfig& config = {});
 };

 }  // namespace opr
-}  // namespace mgb
+}  // namespace mgb

+4 -4   src/opr/include/megbrain/opr/dnn/images2neibs.h

@@ -19,10 +19,10 @@ namespace opr {
 MGB_DEFINE_OPR_CLASS(
         Images2NeibsForward, intl::MegDNNOprWrapperFwd<megdnn::Images2NeibsForward>) // {
 public:
-    Images2NeibsForward(
+    MGE_WIN_DECLSPEC_FUC Images2NeibsForward(
             VarNode* src, const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -31,11 +31,11 @@ using Images2Neibs = Images2NeibsForward;
 MGB_DEFINE_OPR_CLASS(
         Images2NeibsBackward, intl::MegDNNOprWrapperBwd<megdnn::Images2NeibsBackward>) // {
 public:
-    Images2NeibsBackward(
+    MGE_WIN_DECLSPEC_FUC Images2NeibsBackward(
             VarNode* diff, VarNode* src_for_shape, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar diff, SymbolVar src_for_shape, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };

+5 -4   src/opr/include/megbrain/opr/dnn/lrn.h

@@ -19,8 +19,9 @@ namespace opr {

 MGB_DEFINE_OPR_CLASS(LRNForward, intl::MegDNNOprWrapperFwd<megdnn::LRNForward>) // {
 public:
-    LRNForward(VarNode* src, const Param& param, const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC LRNForward(
+            VarNode* src, const Param& param, const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
 };
 using LRN = LRNForward;
@@ -28,10 +29,10 @@ using LRN = LRNForward;
 MGB_DEFINE_OPR_CLASS(
         LRNBackward, intl::MegDNNOprWrapperBwd<megdnn::LRNBackward>) // {
 public:
-    LRNBackward(
+    MGE_WIN_DECLSPEC_FUC LRNBackward(
             VarNode* src, VarNode* dst, VarNode* diff, const Param& param,
             const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param,
             const OperatorNodeConfig& config = {});
 };

+4 -4   src/opr/include/megbrain/opr/dnn/lsq.h

@@ -18,11 +18,11 @@ namespace opr {

 MGB_DEFINE_OPR_CLASS(LSQForward, intl::MegDNNOprWrapperFwd<megdnn::LSQForward>) // {
 public:
-    LSQForward(
+    MGE_WIN_DECLSPEC_FUC LSQForward(
             VarNode* src, VarNode* scale, VarNode* zero_point, VarNode* grad_scale,
             const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar scale, SymbolVar zero_point, SymbolVar grad_scale,
             const Param& param = {}, const OperatorNodeConfig& config = {});
 };
@@ -31,11 +31,11 @@ using LSQ = LSQForward;
 MGB_DEFINE_OPR_CLASS(
         LSQBackward, intl::MegDNNOprWrapperBwd<megdnn::LSQBackward>) // {
 public:
-    LSQBackward(
+    MGE_WIN_DECLSPEC_FUC LSQBackward(
             VarNode* y_grad, VarNode* x, VarNode* scale, VarNode* zero_point,
             VarNode* grad_scale, const Param& param, const OperatorNodeConfig& config);

-    static SymbolVarArray make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
             SymbolVar y_grad, SymbolVar x, SymbolVar scale, SymbolVar zero_point,
             SymbolVar grad_scale, const Param& param = {},
             const OperatorNodeConfig& config = {});

+5 -5   src/opr/include/megbrain/opr/dnn/pooling.h

@@ -22,10 +22,10 @@ MGB_DEFINE_OPR_CLASS(
         PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>,
         public mixin::AlgoChooserHelper) //{
 public:
-    PoolingForward(
+    MGE_WIN_DECLSPEC_FUC PoolingForward(
             VarNode* src, const Param& param, const ExecutionPolicy& policy,
             const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, const Param& param, const OperatorNodeConfig& config = {},
             const ExecutionPolicy& policy = {});

@@ -41,15 +41,15 @@ MGB_DEFINE_OPR_CLASS(
         PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>,
         public mixin::AlgoChooserHelper) //{
 public:
-    PoolingBackward(
+    MGE_WIN_DECLSPEC_FUC PoolingBackward(
             VarNode* src, VarNode* dst, VarNode* diff, const Param& param,
             const ExecutionPolicy& policy, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param,
             const OperatorNodeConfig& config = {}, const ExecutionPolicy& policy = {});

-    size_t get_workspace_size_bytes(
+    MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes(
             const TensorShapeArray& input_shapes,
             const TensorShapeArray& output_shapes) const override final;
 };

+4 -4   src/opr/include/megbrain/opr/dnn/roi_align.h

@@ -20,11 +20,11 @@ namespace opr {
 MGB_DEFINE_OPR_CLASS(
         ROIAlignForward, intl::MegDNNOprWrapperFwd<megdnn::ROIAlignForward>) // {
 public:
-    ROIAlignForward(
+    MGE_WIN_DECLSPEC_FUC ROIAlignForward(
             VarNode* src, VarNode* rois, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar rois, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -33,11 +33,11 @@ using ROIAlign = ROIAlignForward;
 MGB_DEFINE_OPR_CLASS(
         ROIAlignBackward, intl::MegDNNOprWrapperBwd<megdnn::ROIAlignBackward>) // {
 public:
-    ROIAlignBackward(
+    MGE_WIN_DECLSPEC_FUC ROIAlignBackward(
             VarNode* diff, VarNode* src, VarNode* rois, VarNode* index,
             const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar diff, SymbolVar src, SymbolVar rois, SymbolVar index,
             const Param& param = {}, const OperatorNodeConfig& config = {});

+11 -11   src/opr/include/megbrain/opr/dnn/roi_pooling.h

@@ -43,11 +43,11 @@ MGB_DEFINE_OPR_CLASS(
         intl::WorkspaceSizeInfer<intl::OutshapeBySymvarSCNOpr<
                 mixin::MegDNNOprHolderImpl<megdnn::ROIPoolingForward>>>) // {
 public:
-    ROIPoolingForward(
+    MGE_WIN_DECLSPEC_FUC ROIPoolingForward(
             VarNode* src, VarNode* rois, VarNode* dst_shape, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar rois, SymbolVar dst_shape, const Param& param = {},
             const OperatorNodeConfig& config = {});

@@ -76,11 +76,11 @@ using ROIPooling = ROIPoolingForward;
 MGB_DEFINE_OPR_CLASS(
         ROIPoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::ROIPoolingBackward>) // {
 public:
-    ROIPoolingBackward(
+    MGE_WIN_DECLSPEC_FUC ROIPoolingBackward(
             VarNode* diff, VarNode* src, VarNode* rois, VarNode* index,
             const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar diff, SymbolVar src, SymbolVar rois, SymbolVar index,
             const Param& param = {}, const OperatorNodeConfig& config = {});
 };
@@ -94,14 +94,14 @@ MGB_DEFINE_OPR_CLASS(
         DeformablePSROIPoolingForward,
         intl::MegDNNOprWrapperFwd<megdnn::DeformablePSROIPoolingForward>) // {
 public:
-    DeformablePSROIPoolingForward(
+    MGE_WIN_DECLSPEC_FUC DeformablePSROIPoolingForward(
             VarNode* src, VarNode* rois, VarNode* trans, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVarArray make_all(
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make_all(
             SymbolVar src, SymbolVar rois, SymbolVar trans, const Param& param = {},
             const OperatorNodeConfig& config = {});
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar rois, SymbolVar trans, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -110,18 +110,18 @@ using DeformablePSROIPooling = DeformablePSROIPoolingForward;
 MGB_DEFINE_OPR_CLASS(
         DeformablePSROIPoolingBackward, intl::DeformablePSROIPoolingBackwardT) // {
 public:
-    DeformablePSROIPoolingBackward(
+    MGE_WIN_DECLSPEC_FUC DeformablePSROIPoolingBackward(
             VarNode* src, VarNode* rois, VarNode* trans, VarNode* grad, VarNode* count,
             const Param& param, const OperatorNodeConfig& config);
-    static SymbolVarArray make_all(
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make_all(
             SymbolVar src, SymbolVar rois, SymbolVar trans, SymbolVar grad,
             SymbolVar count, const Param& param = {},
             const OperatorNodeConfig& config = {});
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar rois, SymbolVar trans, SymbolVar grad,
             SymbolVar count, const Param& param = {},
             const OperatorNodeConfig& config = {});
-    void scn_do_execute() override;
+    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;

 private:
     void get_output_var_shape(

+5 -5   src/opr/include/megbrain/opr/dnn/sliding_window_transpose.h

@@ -20,10 +20,10 @@ MGB_DEFINE_OPR_CLASS(
         SlidingWindowTransposeForward,
         intl::MegDNNOprWrapperFwd<megdnn::SlidingWindowTransposeForward>) // {
 public:
-    SlidingWindowTransposeForward(
+    MGE_WIN_DECLSPEC_FUC SlidingWindowTransposeForward(
             VarNode* src, const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -33,11 +33,11 @@ MGB_DEFINE_OPR_CLASS(
         SlidingWindowTransposeBackward,
         intl::MegDNNOprWrapperBwd<megdnn::SlidingWindowTransposeBackward>) // {
 public:
-    SlidingWindowTransposeBackward(
+    MGE_WIN_DECLSPEC_FUC SlidingWindowTransposeBackward(
             VarNode* diff, VarNode* src_for_shape, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar diff, SymbolVar src_for_shape, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -45,4 +45,4 @@ public:
 }  // namespace opr
 }  // namespace mgb

-// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+4 -4   src/opr/include/megbrain/opr/dnn/tqt.h

@@ -18,11 +18,11 @@ namespace opr {

 MGB_DEFINE_OPR_CLASS(TQTForward, intl::MegDNNOprWrapperFwd<megdnn::TQTForward>) // {
 public:
-    TQTForward(
+    MGE_WIN_DECLSPEC_FUC TQTForward(
             VarNode* src, VarNode* scale, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar scale, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -31,11 +31,11 @@ using TQT = TQTForward;
 MGB_DEFINE_OPR_CLASS(
         TQTBackward, intl::MegDNNOprWrapperBwd<megdnn::TQTBackward>) // {
 public:
-    TQTBackward(
+    MGE_WIN_DECLSPEC_FUC TQTBackward(
             VarNode* y_grad, VarNode* x, VarNode* scale, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVarArray make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
             SymbolVar y_grad, SymbolVar x, SymbolVar scale, const Param& param = {},
             const OperatorNodeConfig& config = {});

+16 -16   src/opr/include/megbrain/opr/imgproc.h

@@ -43,7 +43,7 @@ public:
             VarNode* in_tensor, VarNode* mat, VarNode* mat_idx, VarNode* out_shape,
             const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar in_tensor, SymbolVar mat, SymbolVar mat_idx, SymbolVar out_shape,
             const Param& param = {}, const OperatorNodeConfig& config = {});

@@ -89,11 +89,11 @@ public:
             VarNode* mat, VarNode* mat_idx, VarNode* out_diff, VarNode* in_for_shape,
             const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar mat, SymbolVar out_diff, SymbolVar in_for_shape,
             const Param& param = {}, const OperatorNodeConfig& config = {});

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar mat, SymbolVar mat_idx, SymbolVar out_diff,
             SymbolVar in_for_shape, const Param& param = {},
             const OperatorNodeConfig& config = {});
@@ -115,7 +115,7 @@ public:
         return make(src, mat, {}, out_diff, param, config);
     }

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar mat, SymbolVar mat_idx, SymbolVar out_diff,
             const Param& param = {}, const OperatorNodeConfig& config = {});

@@ -141,7 +141,7 @@ public:
             VarNode * in_tensor, VarNode * out_shape, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar in_tensor, SymbolVar out_shape, const Param& param = {},
             const OperatorNodeConfig& config = {});

@@ -175,7 +175,7 @@ public:
             VarNode* out_diff, VarNode* in_for_shape, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar out_diff, SymbolVar in_for_shape, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -187,7 +187,7 @@ public:
             VarNode* in_tensor, VarNode* map, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar in_tensor, SymbolVar map, const Param& param = {},
             const OperatorNodeConfig& config = {});

@@ -203,7 +203,7 @@ public:
             VarNode* map, VarNode* out_diff, VarNode* in_for_shape, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar map, SymbolVar out_diff, SymbolVar in_for_shape,
             const Param& param = {}, const OperatorNodeConfig& config = {});
 };
@@ -215,7 +215,7 @@ public:
             VarNode* src, VarNode* map, VarNode* out_diff, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar map, SymbolVar out_diff, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -240,7 +240,7 @@ public:
             VarNode* in_tensor, VarNode* mat, VarNode* out_shape, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar in_tensor, SymbolVar mat, SymbolVar out_shape,
             const Param& param = {}, const OperatorNodeConfig& config = {});

@@ -278,25 +278,25 @@ public:
             VarNode* src, VarNode* mask_offset, VarNode* mask_val, const Param& param,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar mask_offset, SymbolVar mask_val,
             const Param& param, const OperatorNodeConfig& config = {});

-    DctChannelSelectForward(
+    MGE_WIN_DECLSPEC_FUC DctChannelSelectForward(
             VarNode* src, const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
-    void get_output_var_shape(
+    MGE_WIN_DECLSPEC_FUC void get_output_var_shape(
             const TensorShapeArray& inp_shape,
             TensorShapeArray& out_shape) const override;

-    size_t get_workspace_size_bytes(
+    MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes(
             const TensorShapeArray& input_shapes,
             const TensorShapeArray& output_shapes) const override;
     void scn_do_execute() override;

-    void valid_mask(
+    MGE_WIN_DECLSPEC_FUC void valid_mask(
             const int* mask_offset, int mask_len, const int* mask_val, int mask_val_len,
             const Param& param);
 };

+8 -8   src/opr/include/megbrain/opr/indexing.h

@@ -22,10 +22,10 @@ namespace opr {
 MGB_DEFINE_OPR_CLASS(
         IndexingOneHot, intl::MegDNNOprWrapperFwd<megdnn::IndexingOneHotForward>) // {
 public:
-    IndexingOneHot(
+    MGE_WIN_DECLSPEC_FUC IndexingOneHot(
             VarNode* src, VarNode* index, const Param& param,
             const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar index, const Param& param,
             const OperatorNodeConfig& config = {});

@@ -38,10 +38,10 @@ MGB_DEFINE_OPR_CLASS(
         intl::WorkspaceSizeInfer<cg::SingleCNOperatorNodeBaseT<
                 mixin::MegDNNOprHolderImpl<megdnn::IndexingSetOneHotForward>>>) // {
 public:
-    IndexingSetOneHot(
+    MGE_WIN_DECLSPEC_FUC IndexingSetOneHot(
             VarNode* data, VarNode* index, VarNode* sub, const Param& param,
             const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar data, SymbolVar index, SymbolVar sub, const Param& param,
             const OperatorNodeConfig& config = {});

@@ -62,10 +62,10 @@ private:
 MGB_DEFINE_OPR_CLASS(
         IndexingRemap, intl::MegDNNOprWrapperFwd<megdnn::IndexingRemap>) // {
 public:
-    IndexingRemap(
+    MGE_WIN_DECLSPEC_FUC IndexingRemap(
             VarNode* src, VarNode* map, const Param& param,
             const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar map, const Param& param,
             const OperatorNodeConfig& config = {});

@@ -77,10 +77,10 @@ MGB_DEFINE_OPR_CLASS(
         IndexingRemapBackward,
         intl::MegDNNOprWrapperBwd<megdnn::IndexingRemapBackward>) // {
 public:
-    IndexingRemapBackward(
+    MGE_WIN_DECLSPEC_FUC IndexingRemapBackward(
             VarNode* out_diff, VarNode* map, VarNode* src_for_shape, const Param& param,
             const OperatorNodeConfig& config);
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar out_diff, SymbolVar map, SymbolVar src_for_shape,
             const Param& param, const OperatorNodeConfig& config = {});
 };

+4 -4   src/opr/include/megbrain/opr/internal/indexing_helper.h

@@ -53,12 +53,12 @@ struct AxisIndexer {
     SymbolVar idx;

     //! index an axis on an interval
-    static AxisIndexer make_interval(
+    MGE_WIN_DECLSPEC_FUC static AxisIndexer make_interval(
             AxisNum axis, Maybe<SymbolVar> begin, Maybe<SymbolVar> end,
             Maybe<SymbolVar> step);

     //! index an axis with scalar or vector indexer
-    static AxisIndexer make_index(AxisNum axis, SymbolVar idx);
+    MGE_WIN_DECLSPEC_FUC static AxisIndexer make_index(AxisNum axis, SymbolVar idx);

     /*!
      * \brief return true if axis of *lhs* is larger than (i.e. with smaller
@@ -191,7 +191,7 @@ private:

 #define MGB_DECL_FANCY_INDEXING_OPR_GET(_opr) \
     _opr(VarNode* inp, const IndexDesc& desc, const OperatorNodeConfig& config); \
-    static SymbolVar make( \
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
             SymbolVar inp, const IndexDesc& desc, \
             const OperatorNodeConfig& config = {})

@@ -212,7 +212,7 @@ private:
     _opr(VarNode* inp, VarNode* value, const IndexDesc& desc, \
             const OperatorNodeConfig& config, \
             const InputTensorReplacer& input_tensor_replacer); \
-    static SymbolVar make( \
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
             SymbolVar inp, SymbolVar value, const IndexDesc& desc, \
             const OperatorNodeConfig& config = {}, \
             const InputTensorReplacer& input_tensor_replacer = {})

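The two declaration-generating macros above supply the public interface of every fancy-indexing operator, so tagging the generated make() once exports it for all of them. A hypothetical expansion, with the operator name Subtensor chosen purely for illustration (the invocation site supplies the trailing semicolon):

    // Hypothetical expansion of MGB_DECL_FANCY_INDEXING_OPR_GET(Subtensor);
    // the operator name is an example; the body follows the macro text above.
    Subtensor(VarNode* inp, const IndexDesc& desc, const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar inp, const IndexDesc& desc,
            const OperatorNodeConfig& config = {});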
+24 -17   src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h

@@ -21,8 +21,9 @@ namespace opr {

 namespace intl {
 //! get megdnn handle from comp node
-megdnn::Handle* get_megdnn_handle(CompNode comp_node);
-std::shared_ptr<megdnn::Handle> get_megdnn_handle_shared(CompNode comp_node);
+MGE_WIN_DECLSPEC_FUC megdnn::Handle* get_megdnn_handle(CompNode comp_node);
+MGE_WIN_DECLSPEC_FUC std::shared_ptr<megdnn::Handle> get_megdnn_handle_shared(
+        CompNode comp_node);

 /*!
  * \brief get global megdnn operator asscoated with a computing node
@@ -32,7 +33,7 @@ std::shared_ptr<megdnn::Handle> get_megdnn_handle_shared(CompNode comp_node);
  * * Checksum
  */
 template <typename Opr>
-Opr* get_megdnn_global_opr(CompNode comp_node);
+MGE_WIN_DECLSPEC_FUC Opr* get_megdnn_global_opr(CompNode comp_node);

 template <class Obj>
 class UniqPtrWithCN : public std::unique_ptr<Obj> {
@@ -63,7 +64,8 @@ UniqPtrWithCN<Opr> create_megdnn_opr(CompNode comp_node) {
  * temp storage differs from workspace because the temp storage might
  * depends on runtime layout / pointer address
  */
-DeviceTensorStorage& get_temp_storage(ComputingGraph& graph, CompNode comp_node);
+MGE_WIN_DECLSPEC_FUC DeviceTensorStorage& get_temp_storage(
+        ComputingGraph& graph, CompNode comp_node);

 /*!
  * \brief like get_temp_storage() but returns a DeviceTensorND instead
@@ -79,10 +81,11 @@ namespace mixin {
 namespace megdnn_utils {

 //! add input layout constraint to require all inputs to be contiguous
-void add_input_layout_constraint_contig(OperatorNodeBase& opr);
+MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint_contig(OperatorNodeBase& opr);

 //! called in constructor to add output vars
-void add_output_vars(OperatorNodeBase& opr, size_t nr_output, bool add_workspace);
+MGE_WIN_DECLSPEC_FUC void add_output_vars(
+        OperatorNodeBase& opr, size_t nr_output, bool add_workspace);
 }

 /*!
@@ -110,27 +113,29 @@ protected:
 class MegDNNOprHolder : public cg::mixin::SingleCNOperatorNode {
 public:
     //! call create_opr() internally.
-    void mixin_init_output_comp_node(OperatorNodeBase& self);
+    MGE_WIN_DECLSPEC_FUC void mixin_init_output_comp_node(OperatorNodeBase& self);

     //! recreate operator when stream changes
-    void mixin_on_output_comp_node_stream_changed(OperatorNodeBase& self);
+    MGE_WIN_DECLSPEC_FUC void mixin_on_output_comp_node_stream_changed(
            OperatorNodeBase& self);

-    static void record_megdnn_opr(
+    MGE_WIN_DECLSPEC_FUC static void record_megdnn_opr(
             std::unique_ptr<megdnn::OperatorBase> opr,
             cg::GraphExecutable::ExecDependencyArray& deps);

 protected:
-    ~MegDNNOprHolder() noexcept;
+    MGE_WIN_DECLSPEC_FUC ~MegDNNOprHolder() noexcept;

     //! create actual megdnnn operator
     virtual void create_megdnn_opr() = 0;

     megdnn::OperatorBase* megdnn_opr() const { return m_dnn_opr.get(); }

-    void set_megdnn_opr(std::unique_ptr<megdnn::OperatorBase> opr);
+    MGE_WIN_DECLSPEC_FUC void set_megdnn_opr(std::unique_ptr<megdnn::OperatorBase> opr);

     //! record the megdnn opr owned by this opr to ExecDependencyArray
-    void record_megdnn_opr(cg::GraphExecutable::ExecDependencyArray& deps);
+    MGE_WIN_DECLSPEC_FUC void record_megdnn_opr(
+            cg::GraphExecutable::ExecDependencyArray& deps);

 private:
     std::unique_ptr<megdnn::OperatorBase> m_dnn_opr;
@@ -323,8 +328,10 @@ public:
     using GetWorkspaceLimitImpl = thin_function<size_t(CompNode, size_t)>;
     WorkspaceLimitHook() = default;
     ~WorkspaceLimitHook() = default;
-    static void set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl);
-    static const GetWorkspaceLimitImpl& get_impl(ComputingGraph* graph);
+    MGE_WIN_DECLSPEC_FUC static void set_impl(
+            ComputingGraph* graph, GetWorkspaceLimitImpl impl);
+    MGE_WIN_DECLSPEC_FUC static const GetWorkspaceLimitImpl& get_impl(
+            ComputingGraph* graph);

 private:
     void set_impl(GetWorkspaceLimitImpl impl);
@@ -341,7 +348,7 @@ private:
     MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \
 public: \
     _name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \
-    static SymbolVar make( \
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
             SymbolVar p0, const Param& param = {}, \
             const OperatorNodeConfig& config = {}); \
     }
@@ -352,7 +359,7 @@ public: \
 public: \
     _name(VarNode* p0, VarNode* p1, const Param& param, \
             const OperatorNodeConfig& config); \
-    static SymbolVar make( \
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
             SymbolVar p0, SymbolVar p1, const Param& param = {}, \
             const OperatorNodeConfig& config = {}); \
     }
@@ -362,7 +369,7 @@ public: \
     MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperBwd<megdnn::_name>) \
     _extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2, const Param& param, \
             const OperatorNodeConfig& config); \
-    static SymbolVar make( \
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
             SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {}, \
             const OperatorNodeConfig& config = {}); \
     }

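As a usage illustration for one of the interfaces tagged above: WorkspaceLimitHook::set_impl() installs a per-graph callback that bounds the workspace size handed to algorithm selection. The sketch below is hypothetical caller-side code, not part of this commit; the namespace qualification and the 256 MiB cap are assumptions.

    // Hypothetical caller-side sketch (namespaces assumed, not from this commit).
    #include <algorithm>
    #include <cstddef>

    void cap_workspace_example(mgb::ComputingGraph* graph) {
        using Hook = mgb::opr::intl::WorkspaceLimitHook;  // namespace assumed
        // clamp whatever limit the graph would otherwise report to 256 MiB
        Hook::set_impl(graph, [](mgb::CompNode, size_t limit) -> size_t {
            return std::min<size_t>(limit, size_t(256) << 20);
        });
    }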
+ 24
- 20
src/opr/include/megbrain/opr/io.h View File

@@ -154,7 +154,7 @@ private:
* triggered. * triggered.
* 2. If host data is not contiguous, it would be relayouted on host. * 2. If host data is not contiguous, it would be relayouted on host.
*/ */
MGB_DEFINE_OPR_CLASS(Host2DeviceCopy, intl::HostIONodeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Host2DeviceCopy, intl::HostIONodeBase) // {
class HostValueExecDep; class HostValueExecDep;


public: public:
@@ -203,7 +203,7 @@ public:
return make(graph, host_data, p, config); return make(graph, host_data, p, config);
} }


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
ComputingGraph& graph, const std::shared_ptr<HostTensorND>& host_data, ComputingGraph& graph, const std::shared_ptr<HostTensorND>& host_data,
const Param& param, const OperatorNodeConfig& config); const Param& param, const OperatorNodeConfig& config);


@@ -246,13 +246,14 @@ private:
* *
* \see intl::SharedDeviceTensorBase and VolatileSharedDeviceTensor * \see intl::SharedDeviceTensorBase and VolatileSharedDeviceTensor
*/ */
MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
SharedDeviceTensor, intl::SharedDeviceTensorBase) // {
cg::static_infer::SourceType static_infer_src_type() const override; cg::static_infer::SourceType static_infer_src_type() const override;


public: public:
using Super::Super; using Super::Super;


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
bool const_value, const OperatorNodeConfig& config); bool const_value, const OperatorNodeConfig& config);


@@ -273,7 +274,7 @@ public:
* *
* See SharedDeviceTensorBase::SharedDeviceTensorBase for const_value. * See SharedDeviceTensorBase::SharedDeviceTensorBase for const_value.
*/ */
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
ComputingGraph& graph, const HostTensorND& value, bool const_value, ComputingGraph& graph, const HostTensorND& value, bool const_value,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


@@ -295,7 +296,8 @@ public:
* *
* This opr is usually used in serialized models. * This opr is usually used in serialized models.
*/ */
MGB_DEFINE_OPR_CLASS(SharedDeviceTensorWithFormat, intl::SharedDeviceTensorBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
SharedDeviceTensorWithFormat, intl::SharedDeviceTensorBase) // {
cg::static_infer::SourceType static_infer_src_type() const override; cg::static_infer::SourceType static_infer_src_type() const override;


public: public:
@@ -303,7 +305,7 @@ public:


void init_output_format() override; void init_output_format() override;


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
bool const_value, const OperatorNodeConfig& config); bool const_value, const OperatorNodeConfig& config);


@@ -328,13 +330,14 @@ public:
 *
 * \see intl::SharedDeviceTensorBase and SharedDeviceTensor
 */
MGB_DEFINE_OPR_CLASS(VolatileSharedDeviceTensor, intl::SharedDeviceTensorBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
VolatileSharedDeviceTensor, intl::SharedDeviceTensorBase) // {
NodeProp* do_make_node_prop() const override;

public:
using Super::Super;


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config = {});


@@ -350,7 +353,7 @@ public:
/*!
 * \brief tensor with immutable value
 */
MGB_DEFINE_OPR_CLASS(ImmutableTensor, intl::DeviceTensorHolder) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(ImmutableTensor, intl::DeviceTensorHolder) // {
public:
class Value;
class DevValueCache;
@@ -360,19 +363,19 @@ public:
const OperatorNodeConfig& config);
~ImmutableTensor() noexcept;


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
ComputingGraph& graph, const HostTensorND& val,
const OperatorNodeConfig& config = {});


//! make from DTypeScalar; comp node must be provided in config
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
ComputingGraph& graph, const DTypeScalar& val,
const OperatorNodeConfig& config);


//! get underlying value on device
const DeviceTensorND& value() const;
MGE_WIN_DECLSPEC_FUC const DeviceTensorND& value() const;


const DeviceTensorND& host_value();
MGE_WIN_DECLSPEC_FUC const DeviceTensorND& host_value();


SymbolVar shallow_copy(
ComputingGraph& graph, const OperatorNodeConfig& config) const {
@@ -404,7 +407,7 @@ private:
 *
 * Output var would be placed on copy stream by default.
 */
MGB_DEFINE_OPR_CLASS(Copy, cg::SingleCNIOSameShapeOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Copy, cg::SingleCNIOSameShapeOperatorNodeBase) // {
bool m_mem_fwd_success = false;

void scn_do_execute() override;
@@ -418,7 +421,8 @@ MGB_DEFINE_OPR_CLASS(Copy, cg::SingleCNIOSameShapeOperatorNodeBase) // {


public:
Copy(VarNode* inp, const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar inp, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar inp, const OperatorNodeConfig& config = {});


// for serialization
using Param = megdnn::param::Empty;
@@ -433,11 +437,11 @@ public:
 *
 * \see intl::MultipleDeviceTensorHolderBase
 */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
MultipleDeviceTensorHolder, intl::MultipleDeviceTensorHolderBase) // {
public:
using Super::Super;
static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
ComputingGraph& graph, ValueArray values,
const OperatorNodeConfig& config = {});
};
@@ -447,11 +451,11 @@ public:
 *
 * \see intl::MultipleDeviceTensorHolderBase
 */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
MultipleDeviceTensorWithFormatHolder, intl::MultipleDeviceTensorHolderBase) // {
public:
using Super::Super;
static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
ComputingGraph& graph, ValueArray values,
const OperatorNodeConfig& config = {});




+ 1
- 1
src/opr/include/megbrain/opr/loop.h

@@ -128,7 +128,7 @@ public:
 * which must have no side-effect so a desc could be made for grad
 * opr
 */
static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
DescMaker desc_maker, const Param& param = {},
const OperatorNodeConfig& config = {});




+ 32
- 25
src/opr/include/megbrain/opr/misc.h

@@ -24,17 +24,21 @@
namespace mgb {
namespace opr {


MGB_DEFINE_OPR_CLASS(Argmax, intl::MegDNNOprWrapperFwd<megdnn::Argmax>) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Argmax, intl::MegDNNOprWrapperFwd<megdnn::Argmax>) // {
public:
Argmax(VarNode* src, const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC Argmax(
VarNode* src, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
};


MGB_DEFINE_OPR_CLASS(Argmin, intl::MegDNNOprWrapperFwd<megdnn::Argmin>) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Argmin, intl::MegDNNOprWrapperFwd<megdnn::Argmin>) // {
public:
Argmin(VarNode* src, const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC Argmin(
VarNode* src, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
};


@@ -47,7 +51,7 @@ public:
 * \param[out] out_tensor the first output: \f$(m, n)\f$ sorted output tensor
 * \param[out] indices the second output: \f$(m, n)\f$ sorted indices
 */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ArgsortForward, intl::MegDNNOprWrapperFwd<megdnn::ArgsortForward>) // {
protected:
NodeProp* do_make_node_prop() const override;
@@ -57,23 +61,23 @@ protected:
TensorShapeArray& out_shape) const override;


public:
ArgsortForward(
MGE_WIN_DECLSPEC_FUC ArgsortForward(
VarNode* in_tensor, const Param& param, const OperatorNodeConfig& config);

static std::array<SymbolVar, 2> make(
MGE_WIN_DECLSPEC_FUC static std::array<SymbolVar, 2> make(
SymbolVar in_tensor, const Param& param = {},
const OperatorNodeConfig& config = {});
};
using Argsort = ArgsortForward;


MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ArgsortBackward, intl::MegDNNOprWrapperBwd<megdnn::ArgsortBackward>) // {
public:
ArgsortBackward(
MGE_WIN_DECLSPEC_FUC ArgsortBackward(
VarNode* out_diff, VarNode* indices, VarNode* result_shape,
const Param& param, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar out_diff, SymbolVar indices, SymbolVar result_shape,
const Param& param = {}, const OperatorNodeConfig& config = {});
static SymbolVar make(
@@ -84,16 +88,17 @@ public:
};


//! cumulative sum along given axis
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Cumsum,
cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolderImpl<megdnn::Cumsum>>) // {
void add_input_layout_constraint() override;


public:
Cumsum(VarNode* src, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC Cumsum(
VarNode* src, const Param& param, const OperatorNodeConfig& config);


// for serialization
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar opr, const Param& param, const OperatorNodeConfig& config = {});


protected:
@@ -102,13 +107,14 @@ protected:
};


#if MGB_CUDA
MGB_DEFINE_OPR_CLASS(NvOf, cg::SingleCNOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(NvOf, cg::SingleCNOperatorNodeBase) // {
public:
using Param = megdnn::param::NvOf;
NvOf(VarNode* src, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC NvOf(
VarNode* src, const Param& param, const OperatorNodeConfig& config);


// for serialization
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar opr, const Param& param, const OperatorNodeConfig& config = {});


static SymbolVar make(SymbolVar opr, const OperatorNodeConfig& config = {}) {
@@ -142,22 +148,22 @@ using TopKBase = cg::SingleCNOperatorNode<
 * \brief take values conditionally
 * outputs: values, indices
 */
MGB_DEFINE_OPR_CLASS(CondTake, intl::CondTakeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(CondTake, intl::CondTakeBase) // {
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void add_input_layout_constraint() override;
NodeProp* do_make_node_prop() const override;


public:
CondTake(
MGE_WIN_DECLSPEC_FUC CondTake(
VarNode* data, VarNode* mask, const Param& param,
const OperatorNodeConfig& config);
static std::array<SymbolVar, 2> make(
MGE_WIN_DECLSPEC_FUC static std::array<SymbolVar, 2> make(
SymbolVar data, SymbolVar mask, const Param& param,
const OperatorNodeConfig& config = {});
};


MGB_DEFINE_OPR_CLASS(TopK, intl::TopKBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(TopK, intl::TopKBase) // {
void init_output_dtype() override;
void add_input_layout_constraint() override;
void init_output_static_infer_desc() override;
@@ -165,11 +171,12 @@ MGB_DEFINE_OPR_CLASS(TopK, intl::TopKBase) // {
void record_execute_deps(ExecDependencyArray& deps) override;


public:
TopK(VarNode* data, VarNode* k, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC TopK(
VarNode* data, VarNode* k, const Param& param,
const OperatorNodeConfig& config);


//! note: for KTH_ONLY mode, the second output would be nullptr
static std::array<SymbolVar, 2> make(
MGE_WIN_DECLSPEC_FUC static std::array<SymbolVar, 2> make(
SymbolVar data, SymbolVar k, const Param& param,
const OperatorNodeConfig& config = {});
};


+ 3
- 3
src/opr/include/megbrain/opr/nn_int.h

@@ -29,11 +29,11 @@ MGB_DEFINE_OPR_CLASS(ElemwiseMultiType, intl::ElemwiseMultiTypeBase) // {
public:
using Mode = Param::Mode;


ElemwiseMultiType(
MGE_WIN_DECLSPEC_FUC ElemwiseMultiType(
const VarNodeArrayView& inputs, Param param,
const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const VarNodeArrayView& inputs, Param param,
const OperatorNodeConfig& config = {});


@@ -57,7 +57,7 @@ class AffineInt final : public DynTypeObj {


public:
using Param = megdnn::param::Empty;
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar x, SymbolVar k, SymbolVar b, const Param& param = {},
const OperatorNodeConfig& config = {}) {
return ElemwiseMultiType::make(


+ 6
- 6
src/opr/include/megbrain/opr/rand.h

@@ -41,12 +41,12 @@ protected:


/* ================= RNG with shape ================= */
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \
static SymbolVar make( \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar shape, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
static SymbolVar make( \
@@ -67,13 +67,13 @@ _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG)


/* ================= RNG with input ================= */
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
void add_input_layout_constraint() override; \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \
static _OUTPUTS make( \
MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \
_INPUTS(SymbolVar), const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
void init_output_static_infer_desc() override; \
@@ -110,7 +110,7 @@ using PoissonRNG = intl::PoissonRNG;
using BetaRNG = intl::BetaRNG;
using ShuffleRNG = intl::ShuffleRNGForward;


MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ShuffleRNGBackward,
intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) //{
public:
@@ -118,7 +118,7 @@ ShuffleRNGBackward(
VarNode* out_diff, VarNode* indices, VarNode* result_shape, const Param& param,
const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar out_diff, SymbolVar indices, SymbolVar result_shape,
const Param& param = {}, const OperatorNodeConfig& config = {});
};


+ 6
- 4
src/opr/include/megbrain/opr/standalone/nms_opr.h

@@ -10,7 +10,8 @@ namespace standalone {
 *
 * See the docs in the python operator
 */
MGB_DEFINE_OPR_CLASS(NMSKeep, cg::SingleCNOutshapePureByInshapeOprBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
NMSKeep, cg::SingleCNOutshapePureByInshapeOprBase) // {
public:
struct Param {
//! TAG is used by the serializer to check Param type; here we
@@ -22,11 +23,12 @@ public:
uint32_t max_output; //!< max number of output boxes per batch
};


NMSKeep(VarNode* boxes, const Param& param, const OperatorNodeConfig& config);
~NMSKeep() noexcept;
MGE_WIN_DECLSPEC_FUC NMSKeep(
VarNode* boxes, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC ~NMSKeep() noexcept;


//! factory method to insert the operator into a graph
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar boxes, const Param& param, const OperatorNodeConfig& config = {});


const Param& param() const { return m_param; }


+ 3
- 3
src/opr/include/megbrain/opr/tensor_gen.h

@@ -33,7 +33,7 @@ MGB_DEFINE_OPR_CLASS(Alloc, intl::OutshapeBySymvarSCNOprBase) // {
public:
Alloc(VarNode* shape, DType dtype, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar shape, DType dtype, const OperatorNodeConfig& config = {});


static SymbolVar make(
@@ -61,7 +61,7 @@ public:
VarNode* start, VarNode* stop, VarNode* num, const Param& param,
const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar start, SymbolVar stop, SymbolVar num, const Param& param,
const OperatorNodeConfig& config = {});


@@ -83,7 +83,7 @@ public:
using Param = megdnn::Eye::Param;
Eye(VarNode* shape, const Param& param, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar shape, const Param& param, const OperatorNodeConfig& config = {});


const Param& param() const { return m_param; }


+ 105
- 97
src/opr/include/megbrain/opr/tensor_manip.h

@@ -31,7 +31,7 @@ namespace opr {
 *
 * \param axis output shape of a single axis
 */
MGB_DEFINE_OPR_CLASS(GetVarShape, cg::SingleCNOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(GetVarShape, cg::SingleCNOperatorNodeBase) // {
class ShapeDevValueExecDep;

public:
@@ -46,7 +46,7 @@ public:
}


//! get broadcasted shape
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const VarNodeArrayView& inp, Param axis = {},
const OperatorNodeConfig& config = {});


@@ -62,15 +62,16 @@ private:
DeviceTensorND m_cached_shape_cpu_v{CompNode::default_cpu()}, m_cached_shape_dev_v;


//! update m_cached_shape from m_src_shapes
void update_cached_shape();
MGE_WIN_DECLSPEC_FUC void update_cached_shape();


//! update m_cached_shape for static infer
void update_for_static_infer(const cg::static_infer::InpVal& inp);
MGE_WIN_DECLSPEC_FUC void update_for_static_infer(
const cg::static_infer::InpVal& inp);


NodeProp* do_make_node_prop() const override;
void scn_do_execute() override;
void init_output_static_infer_desc() override;
void record_execute_deps(ExecDependencyArray& deps) override;
MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
MGE_WIN_DECLSPEC_FUC void record_execute_deps(ExecDependencyArray& deps) override;
};


namespace intl {
@@ -82,13 +83,13 @@ MGB_DEFINE_CLS_WITH_SUPER(
ReshapeBrdcastHelper, ReadonlyFwdHelper<OutshapeBySymvarSCNOprBase>) // { ReshapeBrdcastHelper, ReadonlyFwdHelper<OutshapeBySymvarSCNOprBase>) // {
bool m_incompatible_inp_layout = false; bool m_incompatible_inp_layout = false;


void mem_plan_fwd_in2out_readonly() override final;
void outshape_by_symvar_do_get_output_shape(
MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override final;
MGE_WIN_DECLSPEC_FUC void outshape_by_symvar_do_get_output_shape(
TensorShape& dest, const ShapeInferInfo& shpinfo) override final; TensorShape& dest, const ShapeInferInfo& shpinfo) override final;
void scn_do_execute() override final;
void add_input_layout_constraint() override final;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override final;
MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override final;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;


protected: protected:
using Super::Super; using Super::Super;
@@ -118,14 +119,15 @@ protected:
* \param unspec_axis the axis that shape is not specified in input, but should * \param unspec_axis the axis that shape is not specified in input, but should
* be calculated from total number of elements and other dims in dest shape * be calculated from total number of elements and other dims in dest shape
*/ */
MGB_DEFINE_OPR_CLASS(Reshape, intl::ReshapeBrdcastHelper) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Reshape, intl::ReshapeBrdcastHelper) // {
public: public:
using Param = megdnn::param::OptionalAxisV1; using Param = megdnn::param::OptionalAxisV1;


Reshape(VarNode* inp, VarNode* tshp, Param unspec_axis,
MGE_WIN_DECLSPEC_FUC Reshape(
VarNode* inp, VarNode* tshp, Param unspec_axis,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar inp, SymbolVar tshp, Param unspec_axis = {}, SymbolVar inp, SymbolVar tshp, Param unspec_axis = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -150,16 +152,17 @@ private:
/*! /*!
* \brief broadcast tensor value along axes whose shape is 1 * \brief broadcast tensor value along axes whose shape is 1
*/ */
MGB_DEFINE_OPR_CLASS(Broadcast, intl::ReshapeBrdcastHelper) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Broadcast, intl::ReshapeBrdcastHelper) // {
Maybe<TensorLayout> reshapebrdcast_get_dest_layout( Maybe<TensorLayout> reshapebrdcast_get_dest_layout(
const TensorLayout& src, const TensorShape& tshape) const override; const TensorLayout& src, const TensorShape& tshape) const override;


bool reshapebrdcast_output_shape_need_input_shape() const override; bool reshapebrdcast_output_shape_need_input_shape() const override;


public: public:
Broadcast(VarNode* inp, VarNode* tshp, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC Broadcast(
VarNode* inp, VarNode* tshp, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar inp, SymbolVar tshp, const OperatorNodeConfig& config = {}); SymbolVar inp, SymbolVar tshp, const OperatorNodeConfig& config = {});


static SymbolVar make( static SymbolVar make(
@@ -188,10 +191,10 @@ namespace intl {
*/ */
MGB_DEFINE_CLS_WITH_SUPER( MGB_DEFINE_CLS_WITH_SUPER(
AxisManipOprBase, ReadonlyFwdHelper<cg::SingleCNOperatorNodeBase>) // { AxisManipOprBase, ReadonlyFwdHelper<cg::SingleCNOperatorNodeBase>) // {
void mem_plan_fwd_in2out_readonly() override final;
void scn_do_execute() override final;
void init_output_static_infer_desc() override final;
NodeProp* do_make_node_prop() const override;
MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override final;
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override final;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override final;
MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;


protected: protected:
using Super::Super; using Super::Super;
@@ -211,7 +214,7 @@ protected:
* *
* Note that dimensions with shape-1 could be dropped * Note that dimensions with shape-1 could be dropped
*/ */
MGB_DEFINE_OPR_CLASS(Dimshuffle, intl::AxisManipOprBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Dimshuffle, intl::AxisManipOprBase) // {
std::vector<int> m_pattern; std::vector<int> m_pattern;
size_t m_inp_ndim; size_t m_inp_ndim;


@@ -219,15 +222,16 @@ MGB_DEFINE_OPR_CLASS(Dimshuffle, intl::AxisManipOprBase) // {
const TensorLayout& inp_layout) const override; const TensorLayout& inp_layout) const override;


public: public:
Dimshuffle(
MGE_WIN_DECLSPEC_FUC Dimshuffle(
VarNode* inp, const std::vector<int>& pattern, size_t ndim, VarNode* inp, const std::vector<int>& pattern, size_t ndim,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar inp, const std::vector<int>& pattern, size_t ndim = 0, SymbolVar inp, const std::vector<int>& pattern, size_t ndim = 0,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


VarNode* grad(size_t wrt_idx, const VarNodeArray& out_grad) const;
MGE_WIN_DECLSPEC_FUC VarNode* grad(
size_t wrt_idx, const VarNodeArray& out_grad) const;


// used for serialization // used for serialization
struct Param { struct Param {
@@ -256,7 +260,7 @@ public:
* *
* All the axis descs would be processed in order * All the axis descs would be processed in order
*/ */
MGB_DEFINE_OPR_CLASS(AxisAddRemove, intl::AxisManipOprBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(AxisAddRemove, intl::AxisManipOprBase) // {
public: public:
struct AxisDesc { struct AxisDesc {
enum class Method { enum class Method {
@@ -283,11 +287,11 @@ public:
} }
}; };


AxisAddRemove(
MGE_WIN_DECLSPEC_FUC AxisAddRemove(
VarNode* inp, const std::vector<AxisDesc>& desc, VarNode* inp, const std::vector<AxisDesc>& desc,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar inp, const std::vector<AxisDesc>& desc, SymbolVar inp, const std::vector<AxisDesc>& desc,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -313,15 +317,15 @@ public:
private: private:
std::vector<AxisDesc> m_desc; std::vector<AxisDesc> m_desc;


TensorLayout axis_manip_get_output_layout(
const TensorLayout& inp_layout) const override;
MGE_WIN_DECLSPEC_FUC TensorLayout
axis_manip_get_output_layout(const TensorLayout& inp_layout) const override;
}; };


namespace intl { namespace intl {


MGB_DEFINE_CLS_WITH_SUPER(ModifySubtensorImplHelper, FancyIndexingHelper) // { MGB_DEFINE_CLS_WITH_SUPER(ModifySubtensorImplHelper, FancyIndexingHelper) // {
void init_output_static_infer_desc() override final;
void scn_do_execute() override final;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override final;
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override final;


/*!
 * \brief implement the actual modification
@@ -341,18 +345,19 @@ protected:
/*! /*!
* \brief get subtensor in a python-like way * \brief get subtensor in a python-like way
*/ */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Subtensor, intl::ReadonlyFwdHelper<intl::FancyIndexingHelper>) // { Subtensor, intl::ReadonlyFwdHelper<intl::FancyIndexingHelper>) // {
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void mem_plan_fwd_in2out_readonly() override;
void init_rt_force_dynamic_mem_alloc_imply_chain() override;
NodeProp* do_make_node_prop() const override;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override;
MGE_WIN_DECLSPEC_FUC void init_rt_force_dynamic_mem_alloc_imply_chain() override;
MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;


public: public:
Subtensor(VarNode* inp, const IndexDesc& desc, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC Subtensor(
VarNode* inp, const IndexDesc& desc, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar inp, const IndexDesc& desc, SymbolVar inp, const IndexDesc& desc,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
}; };
@@ -360,7 +365,7 @@ public:
/*! /*!
* \brief replace the value of subtensor by another tensor * \brief replace the value of subtensor by another tensor
*/ */
MGB_DEFINE_OPR_CLASS(SetSubtensor, intl::ModifySubtensorImplHelper) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(SetSubtensor, intl::ModifySubtensorImplHelper) // {
void modify(DeviceTensorND& sub, const DeviceTensorND& val) override; void modify(DeviceTensorND& sub, const DeviceTensorND& val) override;
NodeProp* do_make_node_prop() const override; NodeProp* do_make_node_prop() const override;


@@ -371,7 +376,7 @@ public:
/*! /*!
* \brief increase the value of subtensor by another tensor * \brief increase the value of subtensor by another tensor
*/ */
MGB_DEFINE_OPR_CLASS(IncrSubtensor, intl::ModifySubtensorImplHelper) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(IncrSubtensor, intl::ModifySubtensorImplHelper) // {
void modify(DeviceTensorND& sub, const DeviceTensorND& val) override; void modify(DeviceTensorND& sub, const DeviceTensorND& val) override;


public: public:
@@ -384,7 +389,7 @@ public:
* \brief helper for Subtensor with only index * \brief helper for Subtensor with only index
* \param index list of pairs of (axis, index) * \param index list of pairs of (axis, index)
*/ */
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar inp, const std::vector<std::pair<size_t, SymbolVar>>& index, SymbolVar inp, const std::vector<std::pair<size_t, SymbolVar>>& index,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
}; };
@@ -399,7 +404,7 @@ public:
* on this comp_node * on this comp_node
* 3. Specify comp_node for each output in OperatorNodeConfig * 3. Specify comp_node for each output in OperatorNodeConfig
*/ */
MGB_DEFINE_OPR_CLASS(Split, intl::OutshapeBySymvarOprBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Split, intl::OutshapeBySymvarOprBase) // {
public: public:
struct Options { struct Options {
enum class Method { enum class Method {
@@ -428,7 +433,7 @@ public:


Split(VarNode* inp, const Options& opt, const OperatorNodeConfig& config); Split(VarNode* inp, const Options& opt, const OperatorNodeConfig& config);


static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar inp, Options opt, const OperatorNodeConfig& config = {}); SymbolVar inp, Options opt, const OperatorNodeConfig& config = {});


const Options& options() const { return m_opt; } const Options& options() const { return m_opt; }
@@ -444,30 +449,30 @@ private:
Options m_opt; Options m_opt;
size_t m_output_shape_version = 0; size_t m_output_shape_version = 0;


void init_output_comp_node() override;
MGE_WIN_DECLSPEC_FUC void init_output_comp_node() override;


NodeProp* do_make_node_prop() const override;
MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;


void do_execute(ExecEnv& env) override;
MGE_WIN_DECLSPEC_FUC void do_execute(ExecEnv& env) override;


void init_output_static_infer_desc() override;
void outshape_by_symvar_do_get_output_shape(
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
MGE_WIN_DECLSPEC_FUC void outshape_by_symvar_do_get_output_shape(
TensorShape& dest, const ShapeInferInfo& shpinfo) override; TensorShape& dest, const ShapeInferInfo& shpinfo) override;


void mem_plan_fwd_in2out_readonly() override;
MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override;


void add_input_layout_constraint() override;
MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override;


bool infer_shape(
MGE_WIN_DECLSPEC_FUC bool infer_shape(
size_t out_idx, TensorShape& dest, const cg::static_infer::InpVal& inp); size_t out_idx, TensorShape& dest, const cg::static_infer::InpVal& inp);


void on_mem_status_changed();
OprEventCallback get_opr_event_callback() override final;
MGE_WIN_DECLSPEC_FUC void on_mem_status_changed();
MGE_WIN_DECLSPEC_FUC OprEventCallback get_opr_event_callback() override final;


void init_subspec(bool memfwd);
MGE_WIN_DECLSPEC_FUC void init_subspec(bool memfwd);


void on_output_comp_node_stream_changed() override;
void init_rt_force_dynamic_mem_alloc_imply_chain() override;
MGE_WIN_DECLSPEC_FUC void on_output_comp_node_stream_changed() override;
MGE_WIN_DECLSPEC_FUC void init_rt_force_dynamic_mem_alloc_imply_chain() override;
}; };


/*! /*!
@@ -476,12 +481,13 @@ private:
* To concat to a different computing node, specify the destination in * To concat to a different computing node, specify the destination in
* OperatorNodeConfig * OperatorNodeConfig
*/ */
MGB_DEFINE_OPR_CLASS(Concat, cg::SingleCNOutshapePureByInshapeOprBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Concat, cg::SingleCNOutshapePureByInshapeOprBase) // {
public: public:
using Param = megdnn::param::Axis; using Param = megdnn::param::Axis;
Concat(const VarNodeArrayView& inp, int axis, const OperatorNodeConfig& config); Concat(const VarNodeArrayView& inp, int axis, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const VarNodeArrayView& inp, int axis, const VarNodeArrayView& inp, int axis,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -500,15 +506,15 @@ public:
private: private:
int m_axis; int m_axis;


void scn_do_execute() override;
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;


NodeProp* do_make_node_prop() const override;
MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;


void init_output_static_infer_desc() override;
void add_input_layout_constraint() override;
void init_output_comp_node() override;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override;
MGE_WIN_DECLSPEC_FUC void init_output_comp_node() override;


void get_output_var_shape(
MGE_WIN_DECLSPEC_FUC void get_output_var_shape(
const TensorShapeArray& inp_shape, const TensorShapeArray& inp_shape,
TensorShapeArray& out_shape) const override; TensorShapeArray& out_shape) const override;
}; };
@@ -521,27 +527,27 @@ private:
* the begin and the end of inputs[i]'s offsets in output * the begin and the end of inputs[i]'s offsets in output
* \param offsets_val: offsets value on cpu * \param offsets_val: offsets value on cpu
*/ */
MGB_DEFINE_OPR_CLASS(ParamPackConcat, cg::SingleCNOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(ParamPackConcat, cg::SingleCNOperatorNodeBase) // {
//! input pointer buffer //! input pointer buffer
SmallVector<void*> m_inp_ptr; SmallVector<void*> m_inp_ptr;
std::vector<dt_int32> m_offsets; std::vector<dt_int32> m_offsets;
intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr; intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr;


void add_input_layout_constraint() override;
void scn_do_execute() override;
void init_output_static_infer_desc() override;
void init_output_dtype() override;
void on_output_comp_node_stream_changed() override;
MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override;
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
MGE_WIN_DECLSPEC_FUC void init_output_dtype() override;
MGE_WIN_DECLSPEC_FUC void on_output_comp_node_stream_changed() override;


public: public:
using Param = megdnn::param::Empty; using Param = megdnn::param::Empty;


Param param() const { return {}; } Param param() const { return {}; }


ParamPackConcat(
MGE_WIN_DECLSPEC_FUC ParamPackConcat(
VarNodeArray& inp, VarNode* offsets, VarNodeArray& inp, VarNode* offsets,
const std::vector<dt_int32> offsets_val, const OperatorNodeConfig& config); const std::vector<dt_int32> offsets_val, const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const SmallVector<SymbolVar>& inp, const SymbolVar& offsets, const SmallVector<SymbolVar>& inp, const SymbolVar& offsets,
const std::vector<dt_int32> offsets_val, const std::vector<dt_int32> offsets_val,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
@@ -564,24 +570,24 @@ public:
* \param offsets_val: offsets value on cpu * \param offsets_val: offsets value on cpu
* \param shapes: shape of each output * \param shapes: shape of each output
*/ */
MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(ParamPackSplit, cg::SingleCNOperatorNodeBase) // {
TensorShapeArray m_shapes; TensorShapeArray m_shapes;
std::vector<dt_int32> m_offsets; std::vector<dt_int32> m_offsets;


void scn_do_execute() override;
void init_output_static_infer_desc() override;
bool infer_shape(
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
MGE_WIN_DECLSPEC_FUC bool infer_shape(
size_t index, TensorShape& dest, const cg::static_infer::InpVal& inp); size_t index, TensorShape& dest, const cg::static_infer::InpVal& inp);
void init_output_dtype() override;
void mem_plan_fwd_in2out_readonly() override;
void add_input_layout_constraint() override;
MGE_WIN_DECLSPEC_FUC void init_output_dtype() override;
MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override;
MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override;


public: public:
ParamPackSplit(
MGE_WIN_DECLSPEC_FUC ParamPackSplit(
VarNode* src, const std::vector<dt_int32> offsets, TensorShapeArray& shapes, VarNode* src, const std::vector<dt_int32> offsets, TensorShapeArray& shapes,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
const SymbolVar& src, const std::vector<dt_int32> offsets, const SymbolVar& src, const std::vector<dt_int32> offsets,
TensorShapeArray shapes, const OperatorNodeConfig& config = {}); TensorShapeArray shapes, const OperatorNodeConfig& config = {});


@@ -589,7 +595,7 @@ public:


const TensorShapeArray& get_output_shapes() const { return m_shapes; } const TensorShapeArray& get_output_shapes() const { return m_shapes; }


void init_rt_force_dynamic_mem_alloc_imply_chain() override;
MGE_WIN_DECLSPEC_FUC void init_rt_force_dynamic_mem_alloc_imply_chain() override;
}; };


/*! /*!
@@ -597,23 +603,25 @@ public:
* *
* See docs of megdnn params for more details * See docs of megdnn params for more details
*/ */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
RelayoutFormat, intl::MegDNNOprWrapperFwd<megdnn::RelayoutFormat>) // { RelayoutFormat, intl::MegDNNOprWrapperFwd<megdnn::RelayoutFormat>) // {
public: public:
RelayoutFormat(VarNode* src, const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC RelayoutFormat(
VarNode* src, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, const Param& param, const OperatorNodeConfig& config = {}); SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
void init_output_format() override final;
MGE_WIN_DECLSPEC_FUC void init_output_format() override final;
}; };


/*! /*!
* \brief padding the src tensor to dst tensor * \brief padding the src tensor to dst tensor
*/ */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
PaddingForward, intl::MegDNNOprWrapperFwd<megdnn::PaddingForward>) // { PaddingForward, intl::MegDNNOprWrapperFwd<megdnn::PaddingForward>) // {
public: public:
PaddingForward(VarNode* src, const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC PaddingForward(
VarNode* src, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, const Param& param = {}, SymbolVar src, const Param& param = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
}; };
@@ -622,13 +630,13 @@ using Padding = PaddingForward;
/*! /*!
* \brief padding backward * \brief padding backward
*/ */
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
PaddingBackward, intl::MegDNNOprWrapperBwd<megdnn::PaddingBackward>) // { PaddingBackward, intl::MegDNNOprWrapperBwd<megdnn::PaddingBackward>) // {
public: public:
PaddingBackward(
MGE_WIN_DECLSPEC_FUC PaddingBackward(
VarNode* src, VarNode* in_for_shape, const Param& param, VarNode* src, VarNode* in_for_shape, const Param& param,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar in_for_shape, const Param& param = {}, SymbolVar src, SymbolVar in_for_shape, const Param& param = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
}; };


+ 66
- 48
src/opr/include/megbrain/opr/utility.h

@@ -27,7 +27,7 @@ namespace opr {
/*!
 * \brief sleep for specific time on device
 */
MGB_DEFINE_OPR_CLASS(Sleep, cg::SingleCNIOSameShapeOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Sleep, cg::SingleCNIOSameShapeOperatorNodeBase) // {
public:
/*!
 * \brief directly sleep without constructing an opr
@@ -41,9 +41,10 @@ public:
Type(bool d = true, bool h = false) : device(d), host(h) {}
};


Sleep(VarNode* node, double seconds, Type type, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC Sleep(
VarNode* node, double seconds, Type type, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar node, double seconds, Type type = {},
const OperatorNodeConfig& config = {});


@@ -82,13 +83,13 @@ private:
* \param dest_off the offset on which \p dest should be modified; this helps * \param dest_off the offset on which \p dest should be modified; this helps
* multiple Timestamp operator instances * multiple Timestamp operator instances
*/ */
MGB_DEFINE_OPR_CLASS(Timestamp, intl::ForwardInputToOutput) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Timestamp, intl::ForwardInputToOutput) // {
public: public:
Timestamp(
MGE_WIN_DECLSPEC_FUC Timestamp(
VarNode* node, std::shared_ptr<HostTensorND> dest, size_t dest_off, VarNode* node, std::shared_ptr<HostTensorND> dest, size_t dest_off,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar node, std::shared_ptr<HostTensorND> dest, size_t dest_off, SymbolVar node, std::shared_ptr<HostTensorND> dest, size_t dest_off,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -111,14 +112,15 @@ private:
* \brief To make sure inputs' owner oprs finished when executing this operator, * \brief To make sure inputs' owner oprs finished when executing this operator,
* and forwarding input(0) to output. * and forwarding input(0) to output.
*/ */
MGB_DEFINE_OPR_CLASS(VirtualDep, intl::ForwardInputToOutput) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(VirtualDep, intl::ForwardInputToOutput) // {
public: public:
VirtualDep(const VarNodeArray& inputs, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC VirtualDep(
const VarNodeArray& inputs, const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const SymbolVarArray& inputs, const OperatorNodeConfig& config = {}); const SymbolVarArray& inputs, const OperatorNodeConfig& config = {});


NodeProp* do_make_node_prop() const override;
MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;
// void add_input(std::initializer_list<VarNode*> list); // void add_input(std::initializer_list<VarNode*> list);
}; };


@@ -128,7 +130,7 @@ public:
* \brief do not provide any static infer on a var to mark it dynamic; used for * \brief do not provide any static infer on a var to mark it dynamic; used for
* debug purposes * debug purposes
*/ */
MGB_DEFINE_OPR_CLASS(MarkDynamicVar, cg::SingleCNOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(MarkDynamicVar, cg::SingleCNOperatorNodeBase) // {
void scn_do_execute() override; void scn_do_execute() override;
void init_output_static_infer_desc() override {} void init_output_static_infer_desc() override {}
NodeProp* do_make_node_prop() const override; NodeProp* do_make_node_prop() const override;
@@ -136,9 +138,11 @@ MGB_DEFINE_OPR_CLASS(MarkDynamicVar, cg::SingleCNOperatorNodeBase) // {
public: public:
using Param = megdnn::param::Empty; using Param = megdnn::param::Empty;


MarkDynamicVar(VarNode* node, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC MarkDynamicVar(
VarNode* node, const OperatorNodeConfig& config);


static SymbolVar make(SymbolVar node, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar node, const OperatorNodeConfig& config = {});


// for serialization // for serialization
Param param() const { return {}; } Param param() const { return {}; }
@@ -151,7 +155,7 @@ public:
/*! /*!
* \brief inject a callback to be called whenever this operator is executed * \brief inject a callback to be called whenever this operator is executed
*/ */
MGB_DEFINE_OPR_CLASS(CallbackInjector, intl::ForwardInputToOutput) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(CallbackInjector, intl::ForwardInputToOutput) // {
void scn_do_execute_finish(const DeviceTensorND& val) override; void scn_do_execute_finish(const DeviceTensorND& val) override;
cg::static_infer::ValueInferDesc mixin_get_static_infer_desc( cg::static_infer::ValueInferDesc mixin_get_static_infer_desc(
OperatorNodeBase& opr) override; OperatorNodeBase& opr) override;
@@ -197,10 +201,10 @@ public:
callback{std::move(cb)} {} callback{std::move(cb)} {}
}; };


CallbackInjector(
MGE_WIN_DECLSPEC_FUC CallbackInjector(
VarNode* inp, const Param& param, const OperatorNodeConfig& config); VarNode* inp, const Param& param, const OperatorNodeConfig& config);


CallbackInjector(
MGE_WIN_DECLSPEC_FUC CallbackInjector(
VarNodeArray& inp, const Param& param, const OperatorNodeConfig& config); VarNodeArray& inp, const Param& param, const OperatorNodeConfig& config);


//! create the operator disallowing auto dup //! create the operator disallowing auto dup
@@ -226,7 +230,7 @@ public:
return make(inp, Param{cb}, config); return make(inp, Param{cb}, config);
} }


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVarArray inp, const Param& param, SymbolVarArray inp, const Param& param,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -244,12 +248,15 @@ private:
* Useful for removing the reduce when computing grad, so graph optimizer can * Useful for removing the reduce when computing grad, so graph optimizer can
* work well. * work well.
*/ */
MGB_DEFINE_OPR_CLASS(MarkNoBroadcastElemwise, intl::ForwardInputToOutput) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
MarkNoBroadcastElemwise, intl::ForwardInputToOutput) // {
public: public:
using Param = megdnn::param::Empty; using Param = megdnn::param::Empty;
MarkNoBroadcastElemwise(VarNode* input, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC MarkNoBroadcastElemwise(
VarNode* input, const OperatorNodeConfig& config);


static SymbolVar make(SymbolVar input, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar input, const OperatorNodeConfig& config = {});


// for serialization // for serialization
Param param() const { return {}; } Param param() const { return {}; }
@@ -265,14 +272,15 @@ public:
* Currently only used for preventing graph optimizer from removing some var so * Currently only used for preventing graph optimizer from removing some var so
* its gradient can be correctly computed. * its gradient can be correctly computed.
*/ */
MGB_DEFINE_OPR_CLASS(Identity, intl::ForwardInputToOutput) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Identity, intl::ForwardInputToOutput) // {
NodeProp* do_make_node_prop() const override; NodeProp* do_make_node_prop() const override;


public: public:
using Param = megdnn::param::Empty; using Param = megdnn::param::Empty;
Identity(VarNode* input, const OperatorNodeConfig& config); Identity(VarNode* input, const OperatorNodeConfig& config);


static SymbolVar make(SymbolVar input, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar input, const OperatorNodeConfig& config = {});


// for serialization // for serialization
Param param() const { return {}; } Param param() const { return {}; }
@@ -288,7 +296,7 @@ public:
* *
* raise UnequalError during exec if tensor not equal * raise UnequalError during exec if tensor not equal
*/ */
MGB_DEFINE_OPR_CLASS(AssertEqual, intl::ForwardInputToOutput) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(AssertEqual, intl::ForwardInputToOutput) // {
bool m_throw_on_error = true; bool m_throw_on_error = true;
HostTensorND m_hv; HostTensorND m_hv;


@@ -298,11 +306,11 @@ public:
using Param = megdnn::param::AssertEqual; using Param = megdnn::param::AssertEqual;


//! \p expect and \p get are only used for error message //! \p expect and \p get are only used for error message
AssertEqual(
MGE_WIN_DECLSPEC_FUC AssertEqual(
VarNode* expect, VarNode* get, VarNode* err, const Param& param, VarNode* expect, VarNode* get, VarNode* err, const Param& param,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar expect, SymbolVar get, const Param& param = {}, SymbolVar expect, SymbolVar get, const Param& param = {},
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});


@@ -310,7 +318,7 @@ public:
void disable_throw_on_error() { m_throw_on_error = false; } void disable_throw_on_error() { m_throw_on_error = false; }


//! for serialization and shallow copy //! for serialization and shallow copy
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar expect, SymbolVar get, SymbolVar err, const Param& param, SymbolVar expect, SymbolVar get, SymbolVar err, const Param& param,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


@@ -331,14 +339,15 @@ private:
* \brief output equals to input, but grad(input) would be replaced by return * \brief output equals to input, but grad(input) would be replaced by return
* value of given callback at runtime * value of given callback at runtime
*/ */
MGB_DEFINE_OPR_CLASS(SetGrad, intl::ForwardInputToOutput) // {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(SetGrad, intl::ForwardInputToOutput) // {
public: public:
using GradGetter = thin_function<SymbolVar(const SetGrad&)>; using GradGetter = thin_function<SymbolVar(const SetGrad&)>;


SetGrad(VarNode* input, const GradGetter& grad_getter,
MGE_WIN_DECLSPEC_FUC SetGrad(
VarNode* input, const GradGetter& grad_getter,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar input, const GradGetter& grad_getter,
             const OperatorNodeConfig& config = {});

@@ -354,7 +363,8 @@ private:
 /*!
  * \brief get a special marker for a grad being invalid
  */
-MGB_DEFINE_OPR_CLASS(InvalidGrad, cg::SingleCNIOSameShapeOperatorNodeBase) // {
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        InvalidGrad, cg::SingleCNIOSameShapeOperatorNodeBase) // {
     const OperatorNodeBase* m_grad_opr;
     size_t m_inp_idx;

@@ -364,9 +374,11 @@ MGB_DEFINE_OPR_CLASS(InvalidGrad, cg::SingleCNIOSameShapeOperatorNodeBase) // {

 public:
     //! \p vinp should be grad_opr.input(inp_idx), unless in shallow copy
-    InvalidGrad(VarNode* vinp, const OperatorNodeBase* grad_opr, size_t inp_idx);
+    MGE_WIN_DECLSPEC_FUC InvalidGrad(
+            VarNode* vinp, const OperatorNodeBase* grad_opr, size_t inp_idx);

-    static VarNode* make(const OperatorNodeBase& grad_opr, size_t inp_idx);
+    MGE_WIN_DECLSPEC_FUC static VarNode* make(
+            const OperatorNodeBase& grad_opr, size_t inp_idx);

     size_t inp_idx() const { return m_inp_idx; }

@@ -380,7 +392,7 @@ public:
  * This operator exists so graph optimization can be performed without actual
  * grad oprs. This operator must be expanded before graph execution.
  */
-MGB_DEFINE_OPR_CLASS(VirtualGrad, cg::OperatorNodeBase) // {
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(VirtualGrad, cg::OperatorNodeBase) // {
     void do_execute(ExecEnv&) override;
     void init_output_comp_node() override;
     void init_output_static_infer_desc() override;
@@ -390,10 +402,11 @@ MGB_DEFINE_OPR_CLASS(VirtualGrad, cg::OperatorNodeBase) // {
 public:
     using Param = megdnn::param::Empty;

-    VirtualGrad(VarNode* target, VarNode* wrt, const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC VirtualGrad(
+            VarNode* target, VarNode* wrt, const OperatorNodeConfig& config);

     Param param() const { return {}; }
-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar target, SymbolVar wrt, Param param = {},
             const OperatorNodeConfig& config = {});
 };
@@ -403,7 +416,7 @@ public:
  *
  * The gradient w.r.t. \p ys[i] would be \p y_grads[i]
  */
-MGB_DEFINE_OPR_CLASS(VirtualLoss, cg::OperatorNodeBase) // {
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(VirtualLoss, cg::OperatorNodeBase) // {
     void do_execute(ExecEnv&) override;
     void init_output_comp_node() override;
     void init_output_static_infer_desc() override;
@@ -414,9 +427,10 @@ public:
     using Param = megdnn::param::Empty;

     //! the first half of \p inputs contain ys, and the remaining are y_grads
-    VirtualLoss(const VarNodeArray& inputs, const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC VirtualLoss(
+            const VarNodeArray& inputs, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             const SymbolVarArray& ys, const SymbolVarArray& y_grads, Param param = {},
             const OperatorNodeConfig& config = {});

@@ -427,7 +441,8 @@ public:
 class InvalidGrad {
 public:
     using OperatorNodeBase = cg::OperatorNodeBase;
-    [[noreturn]] static VarNode* make(const OperatorNodeBase& grad_opr, size_t inp_idx);
+    [[noreturn]] MGE_WIN_DECLSPEC_FUC static VarNode* make(
+            const OperatorNodeBase& grad_opr, size_t inp_idx);
 };
 #endif // MGB_ENABLE_GRAD

@@ -447,15 +462,15 @@ public:
  * \see VarNode::Flag::NO_MEM_RECLAIM for eliminating only dynamic memory
  *      deallocation
  */
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         PersistentOutputStorage, cg::SingleCNIOSameShapeOperatorNodeBase) // {
 public:
     using Param = megdnn::param::PersistentOutputStorage;

-    PersistentOutputStorage(
+    MGE_WIN_DECLSPEC_FUC PersistentOutputStorage(
             VarNode* inp, const Param& param, const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar inp, const Param& param = {},
             const OperatorNodeConfig& config = {});

@@ -473,16 +488,19 @@ private:
     void record_execute_deps(ExecDependencyArray& deps) override;
 };

-MGB_DEFINE_OPR_CLASS(RequireInputDynamicStorage, intl::ForwardInputToOutput) // {
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        RequireInputDynamicStorage, intl::ForwardInputToOutput) // {
 public:
-    RequireInputDynamicStorage(VarNode* input, const OperatorNodeConfig& config);
-    static SymbolVar make(SymbolVar input, const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC RequireInputDynamicStorage(
+            VarNode* input, const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+            SymbolVar input, const OperatorNodeConfig& config = {});
 };

 /*
  * \brief a special op providing shape hint only used in graph compilation (gopt)
  */
-MGB_DEFINE_OPR_CLASS(ShapeHint, cg::SingleCNOperatorNodeBase) // {
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(ShapeHint, cg::SingleCNOperatorNodeBase) // {
     TensorShape m_shape;
     bool m_is_const;

@@ -490,11 +508,11 @@ MGB_DEFINE_OPR_CLASS(ShapeHint, cg::SingleCNOperatorNodeBase) // {
     void init_output_static_infer_desc() override;

 public:
-    ShapeHint(
+    MGE_WIN_DECLSPEC_FUC ShapeHint(
             VarNode* inp, const TensorShape shape, bool is_const,
             const OperatorNodeConfig& config);

-    static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar inp, const TensorShape shape, bool is_const = false,
             const OperatorNodeConfig& config = {});
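
Note: every hunk in this header follows the same two-part pattern of the commit: opr classes move from MGB_DEFINE_OPR_CLASS to MGB_DEFINE_OPR_CLASS_WITH_EXPORT, and the public entry points gain MGE_WIN_DECLSPEC_FUC. A minimal sketch of what such an export macro usually expands to is shown below; the real MGE_WIN_DECLSPEC_FUC is defined elsewhere in this commit, so the branches here are assumptions for illustration only:

    // Sketch only: an illustrative definition of a Windows export macro.
    // The exact branches are assumptions, not the definition from this commit.
    #if defined(_WIN32) && defined(MGE_DLL_EXPORT)
    #define MGE_WIN_DECLSPEC_FUC __declspec(dllexport)  // building megengine.dll
    #elif defined(_WIN32)
    #define MGE_WIN_DECLSPEC_FUC __declspec(dllimport)  // consuming megengine.dll
    #else
    #define MGE_WIN_DECLSPEC_FUC  // non-Windows: default symbol visibility rules
    #endif

    class Example {
    public:
        // only annotated members are reachable across the DLL boundary
        MGE_WIN_DECLSPEC_FUC static int exported_helper();
        int internal_helper();  // not exported from the DLL
    };

Only annotated symbols end up in the DLL's export table, which is what lets megengine_shared export just the needed symbols and keeps the Windows wheel small.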




+2 -2   src/plugin/include/megbrain/plugin/cpu_dispatch_checker.h

@@ -33,10 +33,10 @@ class CPUDispatchChecker final : public PluginBase {
     void check(CompNode cn, cg::OperatorNodeBase* opr);

 public:
-    CPUDispatchChecker(cg::ComputingGraph* graph);
+    MGE_WIN_DECLSPEC_FUC CPUDispatchChecker(cg::ComputingGraph* graph);

     //! get oprs that did not call cpu dispatch
-    auto&& failed_oprs() const { return *m_failed_oprs; }
+    MGE_WIN_DECLSPEC_FUC auto&& failed_oprs() const { return *m_failed_oprs; }
 };
 } // namespace mgb




+7 -5   src/plugin/include/megbrain/plugin/infkern_finder.h

@@ -60,23 +60,25 @@ public:
         using FullRecord = std::vector<std::pair<VarNode*, InputValueRecord>>;
     };

-    InfkernFinder(cg::ComputingGraph* graph, bool record_input_value);
-    ~InfkernFinder() noexcept;
+    MGE_WIN_DECLSPEC_FUC InfkernFinder(
+            cg::ComputingGraph* graph, bool record_input_value);
+    MGE_WIN_DECLSPEC_FUC ~InfkernFinder() noexcept;

     //! this constructor should not be called by user
-    InfkernFinder(cg::ComputingGraph* graph, GlobalState* global_state);
+    MGE_WIN_DECLSPEC_FUC InfkernFinder(
+            cg::ComputingGraph* graph, GlobalState* global_state);

     /*!
      * \brief write execution status to file
      * \return the first operator whose output is not finished; or
      *      nullptr if all finished
      */
-    cg::OperatorNodeBase* write_to_file(const char* fpath);
+    MGE_WIN_DECLSPEC_FUC cg::OperatorNodeBase* write_to_file(const char* fpath);

     /*!
      * \brief get previous input values for dumped operators
      */
-    InputValueRecord::FullRecord get_input_values(size_t opr_id);
+    MGE_WIN_DECLSPEC_FUC InputValueRecord::FullRecord get_input_values(size_t opr_id);
 };

 } // namespace mgb
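
As a usage sketch (only the signatures visible in this hunk are used; graph setup and the output path are placeholders):

    #include "megbrain/plugin/infkern_finder.h"

    void find_stuck_opr(mgb::cg::ComputingGraph* graph) {
        // record input values so they can be inspected after a hang
        mgb::InfkernFinder finder(graph, /*record_input_value=*/true);

        // ... compile and run the graph until it appears to hang ...

        // dump execution status; returns the first unfinished opr, or nullptr
        auto* stuck = finder.write_to_file("/tmp/infkern_status.txt");
        if (stuck) {
            // opr id accessor assumed from OperatorNodeBase
            auto record = finder.get_input_values(stuck->id());
            (void)record;
        }
    }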


+1 -1   src/plugin/include/megbrain/plugin/num_range_checker.h

@@ -54,7 +54,7 @@ class NumRangeChecker final : public PluginBase {

 public:
     using Error = NumRangeCheckerError;
-    NumRangeChecker(cg::ComputingGraph* graph, float range);
+    MGE_WIN_DECLSPEC_FUC NumRangeChecker(cg::ComputingGraph* graph, float range);
 };
 } // namespace mgb




+6 -5   src/plugin/include/megbrain/plugin/opr_footprint.h

@@ -36,7 +36,7 @@ class OprFootprint {
     void add_single_param_json();

     //! be invoked when OprFootprint initilizing.
-    void init_all_footprints();
+    MGE_WIN_DECLSPEC_FUC void init_all_footprints();

 public:
     struct Result {
@@ -74,15 +74,16 @@ public:
     OprFootprint() { init_all_footprints(); }

     //! return footprint rst for associated opr.
-    Result calc_footprint(cg::OperatorNodeBase* opr);
+    MGE_WIN_DECLSPEC_FUC Result calc_footprint(cg::OperatorNodeBase* opr);
     //! get computation of a given operator
-    uint64_t get_computation(cg::OperatorNodeBase* opr);
+    MGE_WIN_DECLSPEC_FUC uint64_t get_computation(cg::OperatorNodeBase* opr);
 #if MGB_ENABLE_JSON
-    std::shared_ptr<json::Value> get_param_json(cg::OperatorNodeBase* opr);
+    MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> get_param_json(
+            cg::OperatorNodeBase* opr);
     //! get opr foot print and graph exec info
     //! the function will recompile graph, AsyncExecutable compiled before will
     //! be invalid
-    static std::shared_ptr<json::Value> get_opr_fp_graph_exec(
+    MGE_WIN_DECLSPEC_FUC static std::shared_ptr<json::Value> get_opr_fp_graph_exec(
             cg::ComputingGraph& graph, const SymbolVarArray& outputs);
 #endif
 };
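
A rough usage sketch of the exported methods above (Result's fields are not shown in this hunk, so only the calls themselves appear):

    #include "megbrain/plugin/opr_footprint.h"

    void report_footprint(mgb::cg::OperatorNodeBase* opr) {
        mgb::OprFootprint footprint;                      // ctor runs init_all_footprints()
        auto result = footprint.calc_footprint(opr);      // per-operator footprint
        uint64_t computation = footprint.get_computation(opr);
    #if MGB_ENABLE_JSON
        auto param_json = footprint.get_param_json(opr);  // std::shared_ptr<json::Value>
        (void)param_json;
    #endif
        (void)result;
        (void)computation;
    }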


+5 -4   src/plugin/include/megbrain/plugin/opr_io_dump.h

@@ -66,7 +66,7 @@ class TextOprIODump final : public OprIODumpBase {
     void dump_var(VarNode* var, bool lazy_sync) override;

 public:
-    TextOprIODump(
+    MGE_WIN_DECLSPEC_FUC TextOprIODump(
             cg::ComputingGraph* graph,
             const std::shared_ptr<FILE>& fout =
                     std::shared_ptr<FILE>(stderr, [](FILE*) {}));
@@ -74,7 +74,7 @@ public:
     TextOprIODump(cg::ComputingGraph* graph, const char* fpath)
             : TextOprIODump(graph, std::shared_ptr<FILE>(fopen(fpath, "w"), fclose)) {}

-    ~TextOprIODump();
+    MGE_WIN_DECLSPEC_FUC ~TextOprIODump();

     void flush_lazy() override;

@@ -109,8 +109,9 @@ class BinaryOprIODump final : public OprIODumpBase {
     void dump_var(VarNode* var, bool lazy_sync) override;

 public:
-    BinaryOprIODump(cg::ComputingGraph* graph, std::string output_dir);
-    ~BinaryOprIODump();
+    MGE_WIN_DECLSPEC_FUC BinaryOprIODump(
+            cg::ComputingGraph* graph, std::string output_dir);
+    MGE_WIN_DECLSPEC_FUC ~BinaryOprIODump();
     void flush_lazy() override;
 };
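
The two dumpers can be attached to a graph roughly as follows (a sketch using only the constructor signatures in this hunk; graph creation is out of scope):

    #include "megbrain/plugin/opr_io_dump.h"

    void attach_io_dumpers(mgb::cg::ComputingGraph* graph) {
        // text dump of operator inputs/outputs to a file path
        mgb::TextOprIODump text_dump(graph, "opr_io.txt");

        // binary dump: one file per var, written into the given directory
        mgb::BinaryOprIODump bin_dump(graph, "io_dump_dir/");

        // ... compile and run the graph while the dumpers are alive ...
    }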




+3 -3   src/plugin/include/megbrain/plugin/profiler.h

@@ -83,13 +83,13 @@ class GraphProfiler final : public PluginBase {
     void record_event(CompNodeEventPtr& dest, CompNode comp_node);

 public:
-    GraphProfiler(cg::ComputingGraph* graph);
-    ~GraphProfiler() noexcept;
+    MGE_WIN_DECLSPEC_FUC GraphProfiler(cg::ComputingGraph* graph);
+    MGE_WIN_DECLSPEC_FUC ~GraphProfiler() noexcept;

     /*!
      * \brief convert only profiling result to json
      */
-    std::shared_ptr<json::Object> to_json() const;
+    MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Object> to_json() const;

     /*!
      * \brief dump to visualizer format


+1 -1   src/plugin/include/megbrain/plugin/var_value_checker.h

@@ -60,7 +60,7 @@ class VarValueChecker final : public PluginBase {
 public:
     using Error = opr::AssertEqual::UnequalError;

-    VarValueChecker(
+    MGE_WIN_DECLSPEC_FUC VarValueChecker(
             ComputingGraph* graph, size_t var_switch_interval = 1,
             size_t init_var_idx = 0);
 };


+4 -0   src/serialization/include/megbrain/serialization/extern_c_opr.h

@@ -16,7 +16,11 @@
 #include <stdint.h>
 #include <string.h>

+#ifdef MGE_DLL_EXPORT
+#define MGB_PUBLIC __declspec(dllexport)
+#else
 #define MGB_PUBLIC __attribute__((visibility("default")))
+#endif

 #ifdef __cplusplus
 extern "C" {
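
Only the dllexport branch of MGB_PUBLIC is added here. When a header like this is consumed by code linking against the DLL, the conventional counterpart is a dllimport branch; the sketch below shows that usual full pattern, where the consumer-side branch is an assumption and not part of this diff:

    /* sketch of the usual export/import pattern; only dllexport exists in this diff */
    #if defined(_WIN32)
    #if defined(MGE_DLL_EXPORT)
    #define MGB_PUBLIC __declspec(dllexport) /* building the DLL */
    #else
    #define MGB_PUBLIC __declspec(dllimport) /* hypothetical consumer-side branch */
    #endif
    #else
    #define MGB_PUBLIC __attribute__((visibility("default")))
    #endif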


+13 -10   src/serialization/include/megbrain/serialization/extern_c_opr_io.h

@@ -19,7 +19,8 @@ namespace mgb {
 namespace opr {

 //! an operator to run extern C oprs
-MGB_DEFINE_OPR_CLASS(ExternCOprRunner, cg::SingleCNOutshapePureByInshapeOprBase) // {
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        ExternCOprRunner, cg::SingleCNOutshapePureByInshapeOprBase) // {
     std::shared_ptr<MGBOprDesc> m_desc;
     //! store ExternCOprRunner opr full dump name
     std::string m_dump_name;
@@ -40,12 +41,12 @@ MGB_DEFINE_OPR_CLASS(ExternCOprRunner, cg::SingleCNOutshapePureByInshapeOprBase)
             std::shared_ptr<MGBOprDesc> desc, const OperatorNodeConfig& config);

 public:
-    ExternCOprRunner(
+    MGE_WIN_DECLSPEC_FUC ExternCOprRunner(
             std::string& name, const VarNodeArray& inputs,
             std::shared_ptr<MGBOprDesc> desc, const OperatorNodeConfig& config);

     //! create from MGBOprDesc and steal its reference
-    static cg::OperatorNodeBase* make_from_desc(
+    MGE_WIN_DECLSPEC_FUC static cg::OperatorNodeBase* make_from_desc(
             std::string& name, const VarNodeArray& inputs, MGBOprDesc* desc,
             const OperatorNodeConfig& config = {});

@@ -61,7 +62,7 @@ public:
      * \param data_len length of \p data
      * \param output_dtypes predefined output dtypes
      */
-    static cg::OperatorNodeBase* make_placeholder(
+    MGE_WIN_DECLSPEC_FUC static cg::OperatorNodeBase* make_placeholder(
             const SymbolVarArray& inputs, const TensorShapeArray& output_shapes,
             const char* name, const void* data, size_t data_len,
             const OperatorNodeConfig& config = {},
@@ -71,28 +72,30 @@ public:
      * \brief unregister a MGBOprLoader
      * \return whether any loader is removed (i.e. whether the name exists)
      */
-    static bool unregister_loader(const char* name);
+    MGE_WIN_DECLSPEC_FUC static bool unregister_loader(const char* name);

     //! impl for serialization dump
-    static void dump(
+    MGE_WIN_DECLSPEC_FUC static void dump(
             serialization::OprDumpContext& ctx, const cg::OperatorNodeBase& opr);

     //! impl for serialization load
-    static cg::OperatorNodeBase* load(
+    MGE_WIN_DECLSPEC_FUC static cg::OperatorNodeBase* load(
             serialization::OprLoadContext& ctx, const cg::VarNodeArray& inputs,
             const OperatorNodeConfig& config);

     //! impl for serialization shallow copy
-    static cg::OperatorNodeBase* shallow_copy(
+    MGE_WIN_DECLSPEC_FUC static cg::OperatorNodeBase* shallow_copy(
             const serialization::OprShallowCopyContext& ctx,
             const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
             const OperatorNodeConfig& config);

     //! helper for converting TensorShape to MGBTensorShape
-    static ::MGBTensorShape tensor_shape_to_c(const TensorShape& shape);
+    MGE_WIN_DECLSPEC_FUC static ::MGBTensorShape tensor_shape_to_c(
+            const TensorShape& shape);

     //! helper for converting MGBTensorShape to TensorShape
-    static TensorShape tensor_shape_from_c(const MGBTensorShape& shape);
+    MGE_WIN_DECLSPEC_FUC static TensorShape tensor_shape_from_c(
+            const MGBTensorShape& shape);

     const std::string& get_dump_name() { return m_dump_name; }
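
A small round-trip sketch for the two exported shape helpers above (the brace-initialized TensorShape is an assumption about its constructor):

    #include "megbrain/serialization/extern_c_opr_io.h"

    void shape_roundtrip() {
        mgb::TensorShape shape{2, 3, 4};  // assumed initializer-list construction
        ::MGBTensorShape c_shape = mgb::opr::ExternCOprRunner::tensor_shape_to_c(shape);
        mgb::TensorShape back = mgb::opr::ExternCOprRunner::tensor_shape_from_c(c_shape);
        (void)back;
    }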




+9 -6   src/serialization/include/megbrain/serialization/file.h

@@ -72,11 +72,12 @@ public:
     virtual SharedBuffer read_shared(size_t size);

     //! create an InputFile correspoding to a file on local file system
-    static std::unique_ptr<InputFile> make_fs(const char* path);
+    MGE_WIN_DECLSPEC_FUC static std::unique_ptr<InputFile> make_fs(const char* path);

     //! create an InputFile correspoding to a memory region; the memory
     //! region must be alive throughout lifespan of this InputFile
-    static std::unique_ptr<InputFile> make_mem_proxy(const void* ptr, size_t size);
+    MGE_WIN_DECLSPEC_FUC static std::unique_ptr<InputFile> make_mem_proxy(
+            const void* ptr, size_t size);

     /*!
      * \brief create an InputFile that would directly reuse the memory
@@ -86,7 +87,7 @@ public:
      *      If this is set to true, tensor storage can be aggressively
      *      shared by reusing the buffer for alignment.
      */
-    static std::unique_ptr<InputFile> make_mem_proxy(
+    MGE_WIN_DECLSPEC_FUC static std::unique_ptr<InputFile> make_mem_proxy(
             std::shared_ptr<void> ptr, size_t size, bool writable = true);
 };

@@ -108,7 +109,8 @@ public:
     virtual size_t tell() = 0;

     //! create an OutputFile correspoding to a file on local file system
-    static std::unique_ptr<OutputFile> make_fs(const char* path, char mode = 'w');
+    MGE_WIN_DECLSPEC_FUC static std::unique_ptr<OutputFile> make_fs(
+            const char* path, char mode = 'w');

     /*!
      * \brief create an OutputFile to write to a std::vector
@@ -116,8 +118,9 @@ public:
      *      Note that the vector must be alive throughout lifespan of this
      *      OutputFile. Current content in *buf* would not be cleared.
      */
-    static std::unique_ptr<OutputFile> make_vector_proxy(std::vector<uint8_t>* buf);
+    MGE_WIN_DECLSPEC_FUC static std::unique_ptr<OutputFile> make_vector_proxy(
+            std::vector<uint8_t>* buf);
 };

 } // namespace serialization
 } // namespace mgb
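
The exported factories above cover the common ways a dump is opened for (de)serialization; a sketch using only those signatures (file names are placeholders):

    #include "megbrain/serialization/file.h"

    #include <vector>

    void open_serialization_files() {
        using namespace mgb::serialization;

        // read a dump from the local file system
        auto in_fs = InputFile::make_fs("model.mge");

        // or wrap a caller-owned memory region (must outlive the InputFile)
        std::vector<uint8_t> blob(1024);
        auto in_mem = InputFile::make_mem_proxy(blob.data(), blob.size());

        // write a dump to disk, or into a caller-owned vector
        auto out_fs = OutputFile::make_fs("model_out.mge", 'w');
        std::vector<uint8_t> dump;
        auto out_vec = OutputFile::make_vector_proxy(&dump);
    }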

+1 -1   src/serialization/include/megbrain/serialization/helper.h

@@ -19,7 +19,7 @@
 namespace mgb {
 namespace serialization {

-void serialize_dtype(
+MGE_WIN_DECLSPEC_FUC void serialize_dtype(
         DType dtype, megdnn::thin_function<void(const void*, size_t)> write_fn);
 DType deserialize_dtype(megdnn::thin_function<void(void*, size_t)> read_fn);




Some files were not shown because too many files changed in this diff
