You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_base.py 3.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """ ctypes library of hetusys and helper functions """
  2. from __future__ import absolute_import
  3. import os
  4. import ctypes
  5. def _check_functions(lib, func_dict):
  6. for func in func_dict:
  7. if hasattr(lib, func):
  8. func_dict[func] = True
  9. # Defines a dictionary to indicate whether to use DNNL.True is use,False not use.
  10. DNNL_LIB = {
  11. 'DnnlMatrixMultiply': False,
  12. 'DnnlMatrixElementwiseMultiplyByConst': False,
  13. 'DnnlMatrixElementwiseMultiply': False,
  14. 'DnnlMatrixElementwiseAddByConst': False,
  15. 'DnnlMatrixElementwiseAdd': False,
  16. 'DnnlMatrixElementwiseDivideByConst': False,
  17. 'DnnlMatrixElementwiseDivide': False,
  18. 'cpu_BroadcastTo': False, # c++
  19. 'cpu_ReduceSumAxisZero': False,
  20. 'cpu_ArraySet': False,
  21. 'cpu_Reshape': False, # c++
  22. 'DnnlSoftmax': False,
  23. 'DnnlSoftmaxCrossEntropy': False, # c++
  24. 'DnnlSoftmaxCrossEntropy_Gradient': False, # c++
  25. 'DnnlSqrt': False,
  26. 'DnnlReciprocalSqrt': False,
  27. 'DnnlTanh': False,
  28. 'DnnlOpposite': False,
  29. 'DnnlSigmoid': False,
  30. 'DnnlConv2d': False,
  31. 'DnnlConv2d_Gradient_of_Filter': False,
  32. 'DnnlConv2d_Gradient_of_Data': False,
  33. 'DnnlAvgPool': False,
  34. 'DnnlAvgPool_Gradient': False,
  35. 'DnnlMaxPool': False,
  36. 'DnnlMaxPool_Gradient': False,
  37. 'DnnlRelu': False,
  38. 'DnnlRelu_Gradient': False,
  39. 'DnnlBatchNorm': False,
  40. 'DnnlBatchNorm_Gradient': False,
  41. 'DnnlBatchNorm_Inference': False,
  42. 'DnnlConcat': False,
  43. 'cpu_Concat_Gradient': False, # c++
  44. 'cpu_Dropout': False, # c++
  45. 'cpu_Dropout_Gradient': False, # c++
  46. 'cpu_Pad': False, # c++
  47. 'cpu_Pad_Gradient': False, # c++
  48. 'cpu_EmbeddingLookup': False, # c++
  49. 'cpu_Transpose': False, # c++
  50. 'cpu_SGDOptimizerUpdate': False, # c++
  51. 'cpu_SGDOptimizerSparseUpdate': False, # c++
  52. 'cpu_MomentumOptimizerUpdate': False, # c++
  53. 'cpu_AdaGradOptimizerUpdate': False, # c++
  54. 'cpu_AdamOptimizerUpdate': False, # c++
  55. 'cpu_UniformInit': False, # c++
  56. 'cpu_NormalInit': False, # c++
  57. 'cpu_TruncatedNormalInit': False, # c++
  58. }
  59. def _load_lib():
  60. """Load libary in build/lib."""
  61. curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
  62. lib_path = os.path.join(curr_path, '../../build/lib/')
  63. path_to_so_file = os.path.join(lib_path, "libc_runtime_api.so")
  64. lib = ctypes.CDLL(path_to_so_file, ctypes.RTLD_GLOBAL)
  65. _check_functions(lib, DNNL_LIB)
  66. return lib
  67. # global library instance
  68. _LIB = _load_lib()
  69. ##################
  70. # Helper Methods #
  71. ##################
  72. def check_call(ret):
  73. """Check the return value of C API call
  74. This function will crash when error occurs.
  75. Wrap every API call with this function
  76. Parameters
  77. ----------
  78. ret : int
  79. return value from API calls
  80. """
  81. assert(ret == 0)
  82. def c_array(ctype, values):
  83. """Create ctypes array from a python array
  84. Parameters
  85. ----------
  86. ctype : ctypes data type
  87. data type of the array we want to convert to
  88. values : tuple or list
  89. data content
  90. Returns
  91. -------
  92. out : ctypes array
  93. Created ctypes array
  94. """
  95. return (ctype * len(values))(*values)