| @@ -21,3 +21,6 @@ ci/resource/prof/model_with_err_assert.mdl filter=lfs diff=lfs merge=lfs -text | |||||
| ci/resource/prof/test_mge.mge filter=lfs diff=lfs merge=lfs -text | ci/resource/prof/test_mge.mge filter=lfs diff=lfs merge=lfs -text | ||||
| lite/test/resource/lite/ax_models/64-58063ce2.axe filter=lfs diff=lfs merge=lfs -text | lite/test/resource/lite/ax_models/64-58063ce2.axe filter=lfs diff=lfs merge=lfs -text | ||||
| imperative/python/test/unit/module/MagicMindRuntimeOprTest.GraphShapeMutable.mlu filter=lfs diff=lfs merge=lfs -text | imperative/python/test/unit/module/MagicMindRuntimeOprTest.GraphShapeMutable.mlu filter=lfs diff=lfs merge=lfs -text | ||||
| lite/test/resource/lite/ax_data_input.npy filter=lfs diff=lfs merge=lfs -text | |||||
| lite/test/resource/lite/ax_data_output.npy filter=lfs diff=lfs merge=lfs -text | |||||
| lite/test/resource/lite/ax_model.mge filter=lfs diff=lfs merge=lfs -text | |||||
| @@ -29,7 +29,6 @@ jobs: | |||||
| uses: actions/checkout@v2 | uses: actions/checkout@v2 | ||||
| - name: Checkout submodules | - name: Checkout submodules | ||||
| run: | | run: | | ||||
| apt update&&apt install ninja-build | |||||
| ./third_party/prepare.sh | ./third_party/prepare.sh | ||||
| ./third_party/install-mkl.sh | ./third_party/install-mkl.sh | ||||
| - name: Build MegEngine | - name: Build MegEngine | ||||
| @@ -58,7 +57,6 @@ jobs: | |||||
| uses: actions/checkout@v2 | uses: actions/checkout@v2 | ||||
| - name: Checkout submodules | - name: Checkout submodules | ||||
| run: | | run: | | ||||
| apt update&&apt install ninja-build | |||||
| ./third_party/prepare.sh | ./third_party/prepare.sh | ||||
| ./third_party/install-mkl.sh | ./third_party/install-mkl.sh | ||||
| - name: Build MegEngine | - name: Build MegEngine | ||||
| @@ -12,7 +12,7 @@ MegEngine is a fast, scalable and easy-to-use deep learning framework, with auto | |||||
| ## Installation | ## Installation | ||||
| **NOTE:** MegEngine now supports Python installation on Linux-64bit/Windows-64bit/MacOS(CPU-Only)-10.14+/Android 7+(CPU-Only) platforms with Python from 3.5 to 3.8. On Windows 10 you can either install the Linux distribution through [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) or install the Windows distribution directly. Many other platforms are supported for inference. | |||||
| **NOTE:** MegEngine now supports Python installation on Linux-64bit/Windows-64bit/MacOS(CPU-Only)-10.14+ platforms with Python from 3.5 to 3.8. On Windows 10 you can either install the Linux distribution through [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) or install the Windows distribution directly. Many other platforms are supported for inference. | |||||
| ### Binaries | ### Binaries | ||||
| @@ -13,7 +13,7 @@ MegEngine 是一个快速、可拓展、易于使用且支持自动求导的深 | |||||
| ## 安装说明 | ## 安装说明 | ||||
| **注意:** MegEngine 现在支持在 Linux-64bit/Windows-64bit/macos-10.14/Android 7+ 及其以上 (MacOS/Android只支持cpu) 等平台上安装 Python 包,支持Python3.5 到 Python3.8。对于 Windows 10 用户,可以通过安装 [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) 进行体验,同时我们也原生支持Windows。MegEngine 也支持在很多其它平台上进行推理运算。 | |||||
| **注意:** MegEngine 现在支持在 Linux-64bit/Windows-64bit/macos-10.14及其以上 (MacOS只支持cpu) 等平台上安装 Python 包,支持Python3.5 到 Python3.8。对于 Windows 10 用户,可以通过安装 [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) 进行体验,同时我们也原生支持Windows。MegEngine 也支持在很多其它平台上进行推理运算。 | |||||
| ### 通过包管理器安装 | ### 通过包管理器安装 | ||||
| @@ -26,8 +26,8 @@ python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html | |||||
| ## 通过源码编译安装 | ## 通过源码编译安装 | ||||
| * CMake 编译细节请参考 [BUILD_README.md](scripts/cmake-build/BUILD_README.md) | |||||
| * Python 绑定编译细节请参考 [BUILD_PYTHON_WHL_README.md](scripts/whl/BUILD_PYTHON_WHL_README.md) | |||||
| * CMake编译细节请参考 [BUILD_README.md](scripts/cmake-build/BUILD_README.md) | |||||
| * Python绑定编译细节请参考 [BUILD_PYTHON_WHL_README.md](scripts/whl/BUILD_PYTHON_WHL_README.md) | |||||
| ## 如何参与贡献 | ## 如何参与贡献 | ||||
| @@ -27,8 +27,7 @@ function build() { | |||||
| -DMGE_WITH_DISTRIBUTED=${DMGE_WITH_DISTRIBUTED} \ | -DMGE_WITH_DISTRIBUTED=${DMGE_WITH_DISTRIBUTED} \ | ||||
| -DMGE_WITH_CUDA=${DMGE_WITH_CUDA} \ | -DMGE_WITH_CUDA=${DMGE_WITH_CUDA} \ | ||||
| -DMGE_WITH_TEST=ON \ | -DMGE_WITH_TEST=ON \ | ||||
| -DCMAKE_BUILD_TYPE=RelWithDebInfo \ | |||||
| -DMGE_WITH_CUSTOM_OP=ON | |||||
| -DCMAKE_BUILD_TYPE=RelWithDebInfo | |||||
| make -j$(($(nproc) * 2)) -I ${build_dir} | make -j$(($(nproc) * 2)) -I ${build_dir} | ||||
| make develop | make develop | ||||
| popd >/dev/null | popd >/dev/null | ||||
| @@ -363,6 +363,58 @@ static inline void trans_8x4_u16( | |||||
| vst1q_u16(dst_ptr + 3 * dst_step, row_3); | vst1q_u16(dst_ptr + 3 * dst_step, row_3); | ||||
| } | } | ||||
| static inline void trans_8x3_u16( | |||||
| const void* src, void* dst, const size_t src_step, const size_t dst_step) { | |||||
| uint16_t* src_ptr = (uint16_t*)src; | |||||
| uint16_t* dst_ptr = (uint16_t*)dst; | |||||
| uint16x4_t src0 = vld1_u16(src_ptr + 0 * src_step); // A0A1A2A3 | |||||
| uint16x4_t src1 = vld1_u16(src_ptr + 1 * src_step); // B0B1B2B3 | |||||
| uint16x4_t src2 = vld1_u16(src_ptr + 2 * src_step); // C0C1C2C3 | |||||
| uint16x4_t src3 = vld1_u16(src_ptr + 3 * src_step); // D0D1D2D3 | |||||
| uint16x4_t src4 = vld1_u16(src_ptr + 4 * src_step); // E0E1E2E3 | |||||
| uint16x4_t src5 = vld1_u16(src_ptr + 5 * src_step); // F0F1F2F3 | |||||
| uint16x4_t src6 = vld1_u16(src_ptr + 6 * src_step); // G0G1G2G3 | |||||
| // H0H1H2 | |||||
| uint16x4_t src7 = | |||||
| vreinterpret_u16_u32(vld1_dup_u32((uint32_t*)(src_ptr + 7 * src_step))); | |||||
| src7 = vld1_lane_u16(src_ptr + 7 * src_step + 2, src7, 2); | |||||
| uint16x4_t ab_low = vzip1_u16(src0, src1); // A0B0A1B1 | |||||
| uint16x4_t ab_high = vzip2_u16(src0, src1); // A2B2A3B3 | |||||
| uint16x4_t cd_low = vzip1_u16(src2, src3); // C0D0C1D1 | |||||
| uint16x4_t cd_high = vzip2_u16(src2, src3); // C2D2C3D3 | |||||
| uint16x4_t ef_low = vzip1_u16(src4, src5); // E0F0E1F1 | |||||
| uint16x4_t ef_high = vzip2_u16(src4, src5); // E2F2E3F3 | |||||
| uint16x4_t gh_low = vzip1_u16(src6, src7); // G0H0G1H1 | |||||
| uint16x4_t gh_high = vzip2_u16(src6, src7); // G2H2G3 | |||||
| uint16x4_t abcd_0 = vreinterpret_u16_u32(vzip1_u32( | |||||
| vreinterpret_u32_u16(ab_low), | |||||
| vreinterpret_u32_u16(cd_low))); // A0B0C0D0 | |||||
| uint16x4_t abcd_1 = vreinterpret_u16_u32(vzip2_u32( | |||||
| vreinterpret_u32_u16(ab_low), | |||||
| vreinterpret_u32_u16(cd_low))); // A1B1C1D1 | |||||
| uint16x4_t abcd_2 = vreinterpret_u16_u32(vzip1_u32( | |||||
| vreinterpret_u32_u16(ab_high), | |||||
| vreinterpret_u32_u16(cd_high))); // A2B2C2D2 | |||||
| uint16x4_t efgh_0 = vreinterpret_u16_u32(vzip1_u32( | |||||
| vreinterpret_u32_u16(ef_low), | |||||
| vreinterpret_u32_u16(gh_low))); // E0F0G0H0 | |||||
| uint16x4_t efgh_1 = vreinterpret_u16_u32(vzip2_u32( | |||||
| vreinterpret_u32_u16(ef_low), | |||||
| vreinterpret_u32_u16(gh_low))); // E1F1G1H1 | |||||
| uint16x4_t efgh_2 = vreinterpret_u16_u32(vzip1_u32( | |||||
| vreinterpret_u32_u16(ef_high), | |||||
| vreinterpret_u32_u16(gh_high))); // E2F2G2H2 | |||||
| uint16x8_t row_0 = vcombine_u16(abcd_0, efgh_0); | |||||
| uint16x8_t row_1 = vcombine_u16(abcd_1, efgh_1); | |||||
| uint16x8_t row_2 = vcombine_u16(abcd_2, efgh_2); | |||||
| vst1q_u16(dst_ptr + 0 * dst_step, row_0); | |||||
| vst1q_u16(dst_ptr + 1 * dst_step, row_1); | |||||
| vst1q_u16(dst_ptr + 2 * dst_step, row_2); | |||||
| } | |||||
| } // anonymous namespace | } // anonymous namespace | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -410,6 +462,8 @@ void transpose_block<Transpose2Byte>( | |||||
| const size_t dst_stride, size_t block_h, size_t block_w) { | const size_t dst_stride, size_t block_h, size_t block_w) { | ||||
| if (block_h == 8 && block_w == 4) { | if (block_h == 8 && block_w == 4) { | ||||
| trans_8x4_u16(src, dst, src_stride, dst_stride); | trans_8x4_u16(src, dst, src_stride, dst_stride); | ||||
| } else if (block_h == 8 && block_w == 3) { | |||||
| trans_8x3_u16(src, dst, src_stride, dst_stride); | |||||
| } else { | } else { | ||||
| transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, block_w); | transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, block_w); | ||||
| } | } | ||||
| @@ -40,6 +40,9 @@ TEST_F(AARCH64, Relayout) { | |||||
| TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); | TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); | ||||
| checker.execl({src, dst}); | checker.execl({src, dst}); | ||||
| } | } | ||||
| TensorLayout src_4_3({1, 3, 112, 256}, {3, 1, 1024, 4}, dtype::Uint16()); | |||||
| TensorLayout dst_4_3({1, 3, 112, 256}, {86016, 28672, 256, 1}, dtype::Uint16()); | |||||
| checker.execl({src_4_3, dst_4_3}); | |||||
| } | } | ||||
| TEST_F(AARCH64, RelayoutNonContig) { | TEST_F(AARCH64, RelayoutNonContig) { | ||||
| @@ -50,7 +50,9 @@ _sh = _stream_helper() | |||||
| def _valid_device(inp): | def _valid_device(inp): | ||||
| if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp): | |||||
| if isinstance(inp, str) and re.match( | |||||
| "^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$", inp | |||||
| ): | |||||
| return True | return True | ||||
| return False | return False | ||||
| @@ -1153,35 +1153,39 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: | |||||
| def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: | def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: | ||||
| r"""Computes the singular value decompositions of input matrix. | |||||
| r"""Returns a singular value decomposition ``A = USVh`` of a matrix (or a stack of matrices) ``x`` , where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows. | |||||
| Args: | Args: | ||||
| inp: input matrix, must has shape `[..., M, N]`. | |||||
| x (Tensor): A input real tensor having the shape ``(..., M, N)`` with ``x.ndim >= 2`` . | |||||
| full_matrices (bool, optional): If ``False`` , ``U`` and ``Vh`` have the shapes ``(..., M, K)`` and ``(..., K, N)`` , respectively, where ``K = min(M, N)`` . If ``True`` , the shapes are ``(..., M, M)`` and ``(..., N, N)`` , respectively. Default: ``False`` . | |||||
| compute_uv (bool, optional): Whether or not to compute ``U`` and ``Vh`` in addition to ``S`` . Default: ``True`` . | |||||
| Note: | |||||
| * naive does not support ``full_matrices`` and ``compute_uv`` as ``True`` . | |||||
| Returns: | Returns: | ||||
| output matrices, `(U, sigma, V)`. | |||||
| Returns a tuple ( ``U`` , ``S`` , ``Vh`` ), which are SVD factors ``U`` , ``S``, ``Vh`` of input matrix ``x``. ( ``U`` , ``Vh`` only returned when ``compute_uv`` is True). | |||||
| ``U`` contains matrices orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True`` , the array must have shape ``(..., M, M)`` . If ``full_matrices`` is ``False`` , the array must have shape ``(..., M, K)`` , where ``K = min(M, N)`` . | |||||
| Examples: | Examples: | ||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3)) | |||||
| _, y, _ = F.svd(x) | |||||
| print(y.numpy().round(decimals=3)) | |||||
| >>> import numpy as np | |||||
| >>> x = Tensor(np.random.randn(9, 6)) | |||||
| >>> y = Tensor(np.random.randn(2, 7, 8, 3)) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| Reconstruction based on reduced SVD, 2D case: | |||||
| >>> U, S, Vh = F.svd(x, full_matrices=False) | |||||
| >>> print(U._tuple_shape, S._tuple_shape, Vh._tuple_shape) | |||||
| (9, 6) (6,) (6, 6) | |||||
| [7.348 1. ] | |||||
| Reconsturction based on reduced SVD, 4D case: | |||||
| >>> u, s, vh = F.svd(y, full_matrices=False) | |||||
| >>> print(u._tuple_shape, s._tuple_shape, vh._tuple_shape) | |||||
| (2, 7, 8, 3) (2, 7, 3) (2, 7, 3, 3) | |||||
| """ | """ | ||||
| op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) | op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) | ||||
| U, sigma, V = apply(op, inp) | |||||
| return U, sigma, V | |||||
| U, S, Vh = apply(op, inp) | |||||
| return U, S, Vh | |||||
| def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor: | def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor: | ||||
| @@ -74,7 +74,7 @@ def calculate_gain( | |||||
| ) -> float: | ) -> float: | ||||
| r"""Returns a recommended gain value (see the table below) for the given nonlinearity | r"""Returns a recommended gain value (see the table below) for the given nonlinearity | ||||
| function. | function. | ||||
| ================= ==================================================== | ================= ==================================================== | ||||
| nonlinearity gain | nonlinearity gain | ||||
| ================= ==================================================== | ================= ==================================================== | ||||
| @@ -126,6 +126,11 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: | |||||
| r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes | r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes | ||||
| input tensor is stored in ``NCHW`` format. | input tensor is stored in ``NCHW`` format. | ||||
| Note: | |||||
| The group conv2d kernel shape in MegEngine is ``(G, O/G, I/G, K, K)``. This | |||||
| function calculates ``fan_out = O/G * K * K`` as default, but PyTorch uses | |||||
| ``fan_out = O * K * K``. | |||||
| Args: | Args: | ||||
| tensor: weight tensor in ``NCHW`` format. | tensor: weight tensor in ``NCHW`` format. | ||||
| """ | """ | ||||
| @@ -141,6 +146,10 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: | |||||
| fan_in = shape[1] | fan_in = shape[1] | ||||
| fan_out = shape[0] | fan_out = shape[0] | ||||
| else: | else: | ||||
| if ndim >= 5: | |||||
| # ignore the groups dimension of group conv2d and group conv3d | |||||
| # FIXME: will be wrong for conv3d | |||||
| shape = shape[1:] | |||||
| num_input_fmaps = shape[1] | num_input_fmaps = shape[1] | ||||
| num_output_fmaps = shape[0] | num_output_fmaps = shape[0] | ||||
| receptive_field_size = 1 | receptive_field_size = 1 | ||||
| @@ -154,7 +163,7 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: | |||||
| def calculate_correct_fan(tensor: Tensor, mode: str) -> float: | def calculate_correct_fan(tensor: Tensor, mode: str) -> float: | ||||
| r"""Calculates fan_in / fan_out value for given weight tensor, depending on given | r"""Calculates fan_in / fan_out value for given weight tensor, depending on given | ||||
| ``mode``. | ``mode``. | ||||
| See :func:`calculate_fan_in_and_fan_out` for details. | See :func:`calculate_fan_in_and_fan_out` for details. | ||||
| Args: | Args: | ||||
| @@ -175,11 +184,11 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float: | |||||
| def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: | def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: | ||||
| r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)` | r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)` | ||||
| where | where | ||||
| .. math:: | .. math:: | ||||
| a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} | a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} | ||||
| Also known as Glorot initialization. Detailed information can be retrieved from | Also known as Glorot initialization. Detailed information can be retrieved from | ||||
| `Understanding the difficulty of training deep feedforward neural networks` - | `Understanding the difficulty of training deep feedforward neural networks` - | ||||
| Glorot, X. & Bengio, Y. (2010). | Glorot, X. & Bengio, Y. (2010). | ||||
| @@ -197,11 +206,11 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: | |||||
| def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: | def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: | ||||
| r"""Fills tensor with random values sampled from | r"""Fills tensor with random values sampled from | ||||
| :math:`\mathcal{N}(0, \text{std}^2)` where | :math:`\mathcal{N}(0, \text{std}^2)` where | ||||
| .. math:: | .. math:: | ||||
| \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} | \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} | ||||
| Also known as Glorot initialization. Detailed information can be retrieved from | Also known as Glorot initialization. Detailed information can be retrieved from | ||||
| `Understanding the difficulty of training deep feedforward neural networks` - | `Understanding the difficulty of training deep feedforward neural networks` - | ||||
| Glorot, X. & Bengio, Y. (2010). | Glorot, X. & Bengio, Y. (2010). | ||||
| @@ -220,11 +229,11 @@ def msra_uniform_( | |||||
| ) -> None: | ) -> None: | ||||
| r"""Fills tensor wilth random values sampled from | r"""Fills tensor wilth random values sampled from | ||||
| :math:`\mathcal{U}(-\text{bound}, \text{bound})` where | :math:`\mathcal{U}(-\text{bound}, \text{bound})` where | ||||
| .. math:: | .. math:: | ||||
| \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}} | \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}} | ||||
| Detailed information can be retrieved from | Detailed information can be retrieved from | ||||
| `Delving deep into rectifiers: Surpassing human-level performance on ImageNet | `Delving deep into rectifiers: Surpassing human-level performance on ImageNet | ||||
| classification` | classification` | ||||
| @@ -251,11 +260,11 @@ def msra_normal_( | |||||
| ) -> None: | ) -> None: | ||||
| r"""Fills tensor wilth random values sampled from | r"""Fills tensor wilth random values sampled from | ||||
| :math:`\mathcal{N}(0, \text{std}^2)` where | :math:`\mathcal{N}(0, \text{std}^2)` where | ||||
| .. math:: | .. math:: | ||||
| \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}} | \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}} | ||||
| Detailed information can be retrieved from | Detailed information can be retrieved from | ||||
| `Delving deep into rectifiers: Surpassing human-level performance on ImageNet | `Delving deep into rectifiers: Surpassing human-level performance on ImageNet | ||||
| classification` | classification` | ||||
| @@ -10,7 +10,7 @@ import numpy as np | |||||
| import pytest | import pytest | ||||
| from megengine import tensor | from megengine import tensor | ||||
| from megengine.module import Conv2d, Linear | |||||
| from megengine.module import Conv1d, Conv2d, Conv3d, Linear | |||||
| from megengine.module.init import calculate_fan_in_and_fan_out, fill_ | from megengine.module.init import calculate_fan_in_and_fan_out, fill_ | ||||
| @@ -32,7 +32,34 @@ def test_calculate_fan_in_and_fan_out(): | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| calculate_fan_in_and_fan_out(l.bias) | calculate_fan_in_and_fan_out(l.bias) | ||||
| l = Conv1d(in_channels=2, out_channels=3, kernel_size=5) | |||||
| fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | |||||
| assert fanin == 2 * 5 | |||||
| assert fanout == 3 * 5 | |||||
| # FIXME: will be wrong for group conv1d | |||||
| # l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2) | |||||
| # fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | |||||
| # assert fanin == 2 // 2 * 5 | |||||
| # assert fanout == 4 // 2 * 5 | |||||
| l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7)) | l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7)) | ||||
| fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | ||||
| assert fanin == 2 * 5 * 7 | assert fanin == 2 * 5 * 7 | ||||
| assert fanout == 3 * 5 * 7 | assert fanout == 3 * 5 * 7 | ||||
| l = Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2) | |||||
| fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | |||||
| assert fanin == 2 // 2 * 5 * 7 | |||||
| assert fanout == 4 // 2 * 5 * 7 | |||||
| # FIXME: will be wrong for conv3d | |||||
| # l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9)) | |||||
| # fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | |||||
| # assert fanin == 2 * 5 * 7 * 9 | |||||
| # assert fanout == 3 * 5 * 7 * 9 | |||||
| l = Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2) | |||||
| fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | |||||
| assert fanin == 2 // 2 * 5 * 7 * 9 | |||||
| assert fanout == 4 // 2 * 5 * 7 * 9 | |||||
| @@ -154,6 +154,21 @@ LITE_API void set_tensor_rt_cache(std::string tensorrt_cache_path); | |||||
| */ | */ | ||||
| LITE_API void dump_tensor_rt_cache(); | LITE_API void dump_tensor_rt_cache(); | ||||
| /** | |||||
| * register the physical and virtual address pair to the mge, some device | |||||
| * need the map from physical to virtual. | |||||
| */ | |||||
| LITE_API bool register_memory_pair( | |||||
| void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, | |||||
| LiteBackend backend = LiteBackend::LITE_DEFAULT); | |||||
| /** | |||||
| * clear the physical and virtual address pair in mge. | |||||
| */ | |||||
| LITE_API bool clear_memory_pair( | |||||
| void* vir_ptr, void* phy_ptr, LiteDeviceType device, | |||||
| LiteBackend backend = LiteBackend::LITE_DEFAULT); | |||||
| } // namespace lite | } // namespace lite | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -160,9 +160,24 @@ LITE_API int LITE_dump_persistent_cache(const char* cache_path); | |||||
| * \brief dump the tensorrt policy cache to file | * \brief dump the tensorrt policy cache to file | ||||
| */ | */ | ||||
| LITE_API int LITE_dump_tensor_rt_cache(); | LITE_API int LITE_dump_tensor_rt_cache(); | ||||
| #endif | |||||
| /** | |||||
| * register the physical and virtual address pair to the mge, some device | |||||
| * need the map from physical to virtual. | |||||
| */ | |||||
| LITE_API int LITE_register_memory_pair( | |||||
| void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, | |||||
| LiteBackend backend); | |||||
| /** | |||||
| * clear the physical and virtual address pair in mge. | |||||
| */ | |||||
| LITE_API int LITE_clear_memory_pair( | |||||
| void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -189,4 +189,19 @@ int LITE_dump_tensor_rt_cache() { | |||||
| LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
| } | } | ||||
| int LITE_register_memory_pair( | |||||
| void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, | |||||
| LiteBackend backend) { | |||||
| LITE_CAPI_BEGIN(); | |||||
| lite::register_memory_pair(vir_ptr, phy_ptr, length, device, backend); | |||||
| LITE_CAPI_END(); | |||||
| } | |||||
| int LITE_clear_memory_pair( | |||||
| void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend) { | |||||
| LITE_CAPI_BEGIN(); | |||||
| lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend); | |||||
| LITE_CAPI_END(); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -42,6 +42,8 @@ class _GlobalAPI(_LiteCObjBase): | |||||
| # ('LITE_set_tensor_rt_cache', [c_char_p]), | # ('LITE_set_tensor_rt_cache', [c_char_p]), | ||||
| ("LITE_dump_persistent_cache", [c_char_p]), | ("LITE_dump_persistent_cache", [c_char_p]), | ||||
| ("LITE_dump_tensor_rt_cache", [c_char_p]), | ("LITE_dump_tensor_rt_cache", [c_char_p]), | ||||
| ("LITE_register_memory_pair", [c_void_p, c_void_p, c_size_t, c_int, c_int]), | |||||
| ("LITE_clear_memory_pair", [c_void_p, c_void_p, c_int, c_int]), | |||||
| ] | ] | ||||
| @@ -121,3 +123,21 @@ class LiteGlobal(object): | |||||
| @staticmethod | @staticmethod | ||||
| def try_coalesce_all_free_memory(): | def try_coalesce_all_free_memory(): | ||||
| LiteGlobal._api.LITE_try_coalesce_all_free_memory() | LiteGlobal._api.LITE_try_coalesce_all_free_memory() | ||||
| @staticmethod | |||||
| def register_memory_pair( | |||||
| vir_ptr, phy_ptr, length, device, backend=LiteBackend.LITE_DEFAULT | |||||
| ): | |||||
| assert isinstance(vir_ptr, c_void_p) and isinstance( | |||||
| phy_ptr, c_void_p | |||||
| ), "clear memory pair only accept c_void_p type." | |||||
| LiteGlobal._api.LITE_register_memory_pair( | |||||
| vir_ptr, phy_ptr, length, device, backend | |||||
| ) | |||||
| @staticmethod | |||||
| def clear_memory_pair(vir_ptr, phy_ptr, device, backend=LiteBackend.LITE_DEFAULT): | |||||
| assert isinstance(vir_ptr, c_void_p) and isinstance( | |||||
| phy_ptr, c_void_p | |||||
| ), "clear memory pair only accept c_void_p type." | |||||
| LiteGlobal._api.LITE_clear_memory_pair(vir_ptr, phy_ptr, device, backend) | |||||
| @@ -212,6 +212,26 @@ void lite::dump_tensor_rt_cache() { | |||||
| #endif | #endif | ||||
| } | } | ||||
| bool lite::register_memory_pair( | |||||
| void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, | |||||
| LiteBackend backend) { | |||||
| LITE_MARK_USED_VAR(vir_ptr); | |||||
| LITE_MARK_USED_VAR(phy_ptr); | |||||
| LITE_MARK_USED_VAR(length); | |||||
| LITE_MARK_USED_VAR(device); | |||||
| LITE_MARK_USED_VAR(backend); | |||||
| LITE_THROW("register_memory_pair is not implement yet!"); | |||||
| } | |||||
| bool lite::clear_memory_pair( | |||||
| void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) { | |||||
| LITE_MARK_USED_VAR(vir_ptr); | |||||
| LITE_MARK_USED_VAR(phy_ptr); | |||||
| LITE_MARK_USED_VAR(device); | |||||
| LITE_MARK_USED_VAR(backend); | |||||
| LITE_THROW("clear_memory_pair is not implement yet!"); | |||||
| } | |||||
| #else // LITE_BUILD_WITH_MGE | #else // LITE_BUILD_WITH_MGE | ||||
| void lite::try_coalesce_all_free_memory() {} | void lite::try_coalesce_all_free_memory() {} | ||||
| @@ -235,6 +255,17 @@ void lite::set_tensor_rt_cache(std::string) { | |||||
| void lite::dump_tensor_rt_cache() { | void lite::dump_tensor_rt_cache() { | ||||
| LITE_THROW("mge is disbale at build time, please build with mge"); | LITE_THROW("mge is disbale at build time, please build with mge"); | ||||
| } | } | ||||
| bool lite::register_memory_pair( | |||||
| void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, | |||||
| LiteBackend beckend) { | |||||
| LITE_THROW("register_memory_pair is not implement yet!"); | |||||
| } | |||||
| bool lite::clear_memory_pair( | |||||
| void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend beckend) { | |||||
| LITE_THROW("clear_memory_pair is not implement yet!"); | |||||
| } | |||||
| #endif | #endif | ||||
| namespace lite { | namespace lite { | ||||
| REGIST_DECRYPTION_METHOD( | REGIST_DECRYPTION_METHOD( | ||||
| @@ -1357,5 +1357,6 @@ TEST(TestNetWork, CambriconDeviceID) { | |||||
| load_device_id(LiteDeviceType::LITE_CAMBRICON, 0, "./model_magicmind.mgb"); | load_device_id(LiteDeviceType::LITE_CAMBRICON, 0, "./model_magicmind.mgb"); | ||||
| } | } | ||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||