You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MatrixTransLink.py 785 B

4 years ago
12345678910111213141516171819202122
  1. from __future__ import absolute_import
  2. import ctypes
  3. from .._base import _LIB
  4. from .. import ndarray as _nd
  5. def matrix_transpose(in_mat, out_mat, perm, stream=None):
  6. assert isinstance(in_mat, _nd.NDArray)
  7. assert isinstance(out_mat, _nd.NDArray)
  8. pointer_func = ctypes.c_int * len(perm)
  9. pointer = pointer_func(*list(perm))
  10. _LIB.DLGpuTranspose(in_mat.handle, out_mat.handle,
  11. pointer, stream.handle if stream else None)
  12. def matrix_transpose_simple(in_mat, out_mat, gpu_buf, stream=None):
  13. assert isinstance(in_mat, _nd.NDArray)
  14. assert isinstance(out_mat, _nd.NDArray)
  15. assert isinstance(gpu_buf, _nd.NDArray)
  16. _LIB.DLGpuTransposeSimple(
  17. in_mat.handle, out_mat.handle, gpu_buf.handle, stream.handle if stream else None)