You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MaxPoolLink.py 895 B

4 years ago
1234567891011121314151617181920
  1. from __future__ import absolute_import
  2. import ctypes
  3. from .._base import _LIB
  4. from .. import ndarray as _nd
  5. def max_pooling2d(in_arr, kernel_H, kernel_W, pooled_layer, padding=0, stride=1, stream=None):
  6. assert isinstance(in_arr, _nd.NDArray)
  7. assert isinstance(pooled_layer, _nd.NDArray)
  8. _LIB.DLGpuMax_Pooling2d(in_arr.handle, kernel_H,
  9. kernel_W, pooled_layer.handle, padding, stride, stream.handle if stream else None)
  10. def max_pooling2d_gradient(in_arr, in_grad_arr, kernel_H, kernel_W, out_grad_arr, padding=0, stride=1, stream=None):
  11. assert isinstance(in_arr, _nd.NDArray)
  12. assert isinstance(in_grad_arr, _nd.NDArray)
  13. assert isinstance(out_grad_arr, _nd.NDArray)
  14. _LIB.DLGpuMax_Pooling2d_gradient(
  15. in_arr.handle, in_grad_arr.handle, kernel_H, kernel_W, out_grad_arr.handle, padding, stride, stream.handle if stream else None)