|
- from __future__ import absolute_import
-
- import ctypes
- from .._base import _LIB
- from .. import ndarray as _nd
-
-
def add_l2_regularization(param, grad, l2reg, stream=None):
    """Forward a dense gradient to the native ``AddL2Regularization`` routine.

    Args:
        param: An ``NDArray``; its ``handle`` is passed to the native call.
        grad: An ``NDArray`` or ``IndexedSlices``. Only ``NDArray`` gradients
            trigger the native call.
        l2reg: Regularization coefficient, converted to a C ``float``.
        stream: Optional stream object. When truthy, its ``handle`` is passed
            to the native call; otherwise ``None`` is passed.

    Returns:
        None.
    """
    assert isinstance(param, _nd.NDArray)
    assert isinstance(grad, (_nd.NDArray, _nd.IndexedSlices))

    # IndexedSlices gradients are not supported yet: accept them, do nothing.
    if not isinstance(grad, _nd.NDArray):
        return

    _LIB.AddL2Regularization(
        param.handle,
        grad.handle,
        ctypes.c_float(l2reg),
        stream.handle if stream else None,
    )
-
-
def sgd_update(param, grad, lr, stream=None):
    """Dispatch an SGD step to the matching native routine.

    Dense (``NDArray``) gradients go to ``SGDOptimizerUpdate``; any other
    accepted gradient goes to ``SGDOptimizerSparseUpdate`` using its
    ``indices`` and ``values`` arrays.

    Args:
        param: An ``NDArray``; its ``handle`` is passed to the native call.
        grad: An ``NDArray`` or ``IndexedSlices``.
        lr: Learning rate, converted to a C ``float``.
        stream: Optional stream object. When truthy, its ``handle`` is passed
            to the native call; otherwise ``None`` is passed.

    Returns:
        None.
    """
    assert isinstance(param, _nd.NDArray)
    assert isinstance(grad, (_nd.NDArray, _nd.IndexedSlices))

    if not isinstance(grad, _nd.NDArray):
        # Sparse path: both components must themselves be NDArrays.
        assert isinstance(grad.indices, _nd.NDArray)
        assert isinstance(grad.values, _nd.NDArray)
        _LIB.SGDOptimizerSparseUpdate(
            param.handle,
            grad.indices.handle,
            grad.values.handle,
            ctypes.c_float(lr),
            stream.handle if stream else None,
        )
        return

    _LIB.SGDOptimizerUpdate(
        param.handle,
        grad.handle,
        ctypes.c_float(lr),
        stream.handle if stream else None,
    )
-
-
def momentum_update(param, grad, velocity, lr, momentum, nesterov, stream=None):
    """Dispatch a momentum step to the matching native routine.

    Dense (``NDArray``) gradients go to ``MomentumOptimizerUpdate``; any other
    accepted gradient goes to ``MomentumOptimizerSparseUpdate`` using its
    ``indices`` and ``values`` arrays.

    Args:
        param: An ``NDArray``; its ``handle`` is passed to the native call.
        grad: An ``NDArray`` or ``IndexedSlices``.
        velocity: An ``NDArray``; its ``handle`` is passed to the native call.
        lr: Learning rate, converted to a C ``float``.
        momentum: Momentum coefficient, converted to a C ``float``.
        nesterov: Flag converted to a C ``bool``.
        stream: Optional stream object. When truthy, its ``handle`` is passed
            to the native call; otherwise ``None`` is passed.

    Returns:
        None.
    """
    assert isinstance(param, _nd.NDArray)
    assert isinstance(grad, (_nd.NDArray, _nd.IndexedSlices))
    assert isinstance(velocity, _nd.NDArray)

    if not isinstance(grad, _nd.NDArray):
        # Sparse path: both components must themselves be NDArrays.
        assert isinstance(grad.indices, _nd.NDArray)
        assert isinstance(grad.values, _nd.NDArray)
        _LIB.MomentumOptimizerSparseUpdate(
            param.handle,
            grad.indices.handle,
            grad.values.handle,
            velocity.handle,
            ctypes.c_float(lr),
            ctypes.c_float(momentum),
            ctypes.c_bool(nesterov),
            stream.handle if stream else None,
        )
        return

    _LIB.MomentumOptimizerUpdate(
        param.handle,
        grad.handle,
        velocity.handle,
        ctypes.c_float(lr),
        ctypes.c_float(momentum),
        ctypes.c_bool(nesterov),
        stream.handle if stream else None,
    )
-
-
def adagrad_update(param, grad, accumulation, lr, eps, stream=None):
    """Dispatch an AdaGrad step to the matching native routine.

    Dense (``NDArray``) gradients go to ``AdaGradOptimizerUpdate``. Any other
    accepted gradient is first deduplicated via ``grad.deduplicate(stream)``,
    passed to ``AdaGradOptimizerSparseUpdate`` through its ``indices`` and
    ``values`` arrays, and then released with ``grad.free_deduplicate()``.

    Args:
        param: An ``NDArray``; its ``handle`` is passed to the native call.
        grad: An ``NDArray`` or ``IndexedSlices``.
        accumulation: An ``NDArray``; its ``handle`` is passed to the native
            call.
        lr: Learning rate, converted to a C ``float``.
        eps: Epsilon term, converted to a C ``float``.
        stream: Optional stream object. When truthy, its ``handle`` is passed
            to the native call; otherwise ``None`` is passed. It is also
            forwarded unchanged to ``grad.deduplicate``.

    Returns:
        None.
    """
    assert isinstance(param, _nd.NDArray)
    assert isinstance(grad, (_nd.NDArray, _nd.IndexedSlices))
    assert isinstance(accumulation, _nd.NDArray)

    if isinstance(grad, _nd.NDArray):
        _LIB.AdaGradOptimizerUpdate(
            param.handle,
            grad.handle,
            accumulation.handle,
            ctypes.c_float(lr),
            ctypes.c_float(eps),
            stream.handle if stream else None,
        )
        return

    grad.deduplicate(stream)
    try:
        assert isinstance(grad.indices, _nd.NDArray)
        assert isinstance(grad.values, _nd.NDArray)
        _LIB.AdaGradOptimizerSparseUpdate(
            param.handle,
            grad.indices.handle,
            grad.values.handle,
            accumulation.handle,
            ctypes.c_float(lr),
            ctypes.c_float(eps),
            stream.handle if stream else None,
        )
    finally:
        # Always undo deduplicate(), even if a check or the native call
        # raises, so the deduplication state is not left behind.
        grad.free_deduplicate()
-
-
def adam_update(param, grad, expavg, expavgsq, lr, beta1, beta2, beta1t, beta2t, eps, stream=None):
    """Dispatch an Adam step to the matching native routine.

    Dense (``NDArray``) gradients go to ``AdamOptimizerUpdate``. Any other
    accepted gradient is first deduplicated via ``grad.deduplicate(stream)``,
    passed to ``AdamOptimizerSparseUpdate`` through its ``indices`` and
    ``values`` arrays, and then released with ``grad.free_deduplicate()``.

    Args:
        param: An ``NDArray``; its ``handle`` is passed to the native call.
        grad: An ``NDArray`` or ``IndexedSlices``.
        expavg: An ``NDArray``; its ``handle`` is passed to the native call.
        expavgsq: An ``NDArray``; its ``handle`` is passed to the native call.
        lr: Learning rate, converted to a C ``float``.
        beta1: Converted to a C ``float``.
        beta2: Converted to a C ``float``.
        beta1t: Converted to a C ``float``.
            NOTE(review): presumably the running power of ``beta1`` - confirm
            against the native kernel.
        beta2t: Converted to a C ``float``.
            NOTE(review): presumably the running power of ``beta2`` - confirm
            against the native kernel.
        eps: Epsilon term, converted to a C ``float``.
        stream: Optional stream object. When truthy, its ``handle`` is passed
            to the native call; otherwise ``None`` is passed. It is also
            forwarded unchanged to ``grad.deduplicate``.

    Returns:
        None.
    """
    assert isinstance(param, _nd.NDArray)
    assert isinstance(grad, (_nd.NDArray, _nd.IndexedSlices))
    assert isinstance(expavg, _nd.NDArray)
    assert isinstance(expavgsq, _nd.NDArray)

    if isinstance(grad, _nd.NDArray):
        _LIB.AdamOptimizerUpdate(
            param.handle,
            grad.handle,
            expavg.handle,
            expavgsq.handle,
            ctypes.c_float(lr),
            ctypes.c_float(beta1),
            ctypes.c_float(beta2),
            ctypes.c_float(beta1t),
            ctypes.c_float(beta2t),
            ctypes.c_float(eps),
            stream.handle if stream else None,
        )
        return

    grad.deduplicate(stream)
    try:
        assert isinstance(grad.indices, _nd.NDArray)
        assert isinstance(grad.values, _nd.NDArray)
        _LIB.AdamOptimizerSparseUpdate(
            param.handle,
            grad.indices.handle,
            grad.values.handle,
            expavg.handle,
            expavgsq.handle,
            ctypes.c_float(lr),
            ctypes.c_float(beta1),
            ctypes.c_float(beta2),
            ctypes.c_float(beta1t),
            ctypes.c_float(beta2t),
            ctypes.c_float(eps),
            stream.handle if stream else None,
        )
    finally:
        # Always undo deduplicate(), even if a check or the native call
        # raises, so the deduplication state is not left behind.
        grad.free_deduplicate()
|