from __future__ import absolute_import import ctypes from .._base import _LIB from .. import ndarray as _nd def normal_init(arr, mean, stddev, seed, stream=None): assert isinstance(arr, _nd.NDArray) _LIB.DLGpuNormalInit(arr.handle, ctypes.c_float(mean), ctypes.c_float( stddev), ctypes.c_ulonglong(seed), stream.handle if stream else None) def uniform_init(arr, lb, ub, seed, stream=None): assert isinstance(arr, _nd.NDArray) _LIB.DLGpuUniformInit(arr.handle, ctypes.c_float(lb), ctypes.c_float( ub), ctypes.c_ulonglong(seed), stream.handle if stream else None) def truncated_normal_init(arr, mean, stddev, seed, stream=None): # time consuming !! assert isinstance(arr, _nd.NDArray) _LIB.DLGpuTruncatedNormalInit(arr.handle, ctypes.c_float(mean), ctypes.c_float( stddev), ctypes.c_ulonglong(seed), stream.handle if stream else None)