|
- """ ctypes library of hetusys and helper functions """
- from __future__ import absolute_import
-
- import os
- import ctypes
-
-
def _check_functions(lib, func_dict):
    """Mark which of the names in ``func_dict`` are exported by ``lib``.

    Every key of ``func_dict`` for which ``hasattr(lib, key)`` is true has
    its value set to True in place; all other entries are left untouched.
    Nothing is returned.
    """
    # Probe first, then flag, so the dict is only written after the scan.
    exported = [name for name in func_dict if hasattr(lib, name)]
    for name in exported:
        func_dict[name] = True
-
# Availability flags for optional native kernels: True means the kernel can
# be used, False means it cannot. Every entry starts as False; after the
# shared library is loaded, _check_functions sets an entry to True when the
# library exports a symbol with that exact name. Entries tagged "# c++" were
# annotated by the original authors (presumably kernels implemented in plain
# C++ rather than via DNNL -- confirm against the native sources).
DNNL_LIB = dict.fromkeys([
    'DnnlMatrixMultiply',
    'DnnlMatrixElementwiseMultiplyByConst',
    'DnnlMatrixElementwiseMultiply',
    'DnnlMatrixElementwiseAddByConst',
    'DnnlMatrixElementwiseAdd',
    'DnnlMatrixElementwiseDivideByConst',
    'DnnlMatrixElementwiseDivide',
    'cpu_BroadcastTo',  # c++
    'cpu_ReduceSumAxisZero',
    'cpu_ArraySet',
    'cpu_Reshape',  # c++
    'DnnlSoftmax',
    'DnnlSoftmaxCrossEntropy',  # c++
    'DnnlSoftmaxCrossEntropy_Gradient',  # c++
    'DnnlSqrt',
    'DnnlReciprocalSqrt',
    'DnnlTanh',
    'DnnlOpposite',
    'DnnlSigmoid',
    'DnnlConv2d',
    'DnnlConv2d_Gradient_of_Filter',
    'DnnlConv2d_Gradient_of_Data',
    'DnnlAvgPool',
    'DnnlAvgPool_Gradient',
    'DnnlMaxPool',
    'DnnlMaxPool_Gradient',
    'DnnlRelu',
    'DnnlRelu_Gradient',
    'DnnlBatchNorm',
    'DnnlBatchNorm_Gradient',
    'DnnlBatchNorm_Inference',
    'DnnlConcat',
    'cpu_Concat_Gradient',  # c++
    'cpu_Dropout',  # c++
    'cpu_Dropout_Gradient',  # c++
    'cpu_Pad',  # c++
    'cpu_Pad_Gradient',  # c++
    'cpu_EmbeddingLookup',  # c++
    'cpu_Transpose',  # c++
    'cpu_SGDOptimizerUpdate',  # c++
    'cpu_SGDOptimizerSparseUpdate',  # c++
    'cpu_MomentumOptimizerUpdate',  # c++
    'cpu_AdaGradOptimizerUpdate',  # c++
    'cpu_AdamOptimizerUpdate',  # c++
    'cpu_UniformInit',  # c++
    'cpu_NormalInit',  # c++
    'cpu_TruncatedNormalInit',  # c++
], False)
-
-
def _load_lib():
    """Open the runtime shared library and record which kernels it exports.

    The library file is ``libc_runtime_api.so`` inside ``../../build/lib/``,
    resolved relative to the directory that contains this module. It is
    opened with ``ctypes.RTLD_GLOBAL`` so its symbols are visible to shared
    libraries loaded later, and every ``DNNL_LIB`` entry whose symbol the
    library exports is switched to True. ctypes raises OSError if the file
    cannot be loaded.

    Returns
    -------
    lib : ctypes.CDLL
        Handle to the loaded library.
    """
    here = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    so_file = os.path.join(here, '../../build/lib/', "libc_runtime_api.so")
    handle = ctypes.CDLL(so_file, ctypes.RTLD_GLOBAL)
    _check_functions(handle, DNNL_LIB)
    return handle


# Process-wide library handle, created at import time: importing this module
# loads the shared library (and fails if it has not been built yet).
_LIB = _load_lib()
-
-
- ##################
- # Helper Methods #
- ##################
-
def check_call(ret):
    """Check the return value of a C API call.

    Wrap every API call with this function. A return value of 0 means
    success; any other value is treated as an error and raises.

    Parameters
    ----------
    ret : int
        return value from API calls

    Raises
    ------
    AssertionError
        If ``ret`` is not 0. The message includes the offending return code.
    """
    # An explicit raise (rather than an ``assert`` statement) keeps the check
    # active under ``python -O``, where assert statements are stripped and
    # failing C calls would otherwise pass silently. AssertionError is kept
    # so existing callers that catch it keep working. str.format is used
    # because this module may still be imported under Python 2.
    if ret != 0:
        raise AssertionError(
            'C API call failed with return code {}'.format(ret))
-
-
def c_array(ctype, values):
    """Build a ctypes array initialised with the given Python values.

    Parameters
    ----------
    ctype : ctypes data type
        element type of the array to build

    values : tuple or list
        elements to copy into the array, in order; must support ``len``

    Returns
    -------
    out : ctypes array
        a fresh ``ctype * len(values)`` array holding ``values``
    """
    array_type = ctype * len(values)
    return array_type(*values)
|