You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Dropout2d.py 2.7 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. import ctypes
  4. import numpy as np
  5. from .._base import DNNL_LIB
  6. #from ..cpu_links import dropout as cpu_dropout
  7. #from ..cpu_links import dropout_gradient as cpu_dropout_gradient
  8. from ..gpu_links import dropout2d_gradient
  9. from ..gpu_links import dropout2d
  10. class Dropout2dOp(Op):
  11. def __init__(self, node_in, keep_prob, ctx=None):
  12. super().__init__(Dropout2dOp, [node_in], ctx)
  13. self.seed = ctypes.c_ulonglong(0)
  14. self.mask = None
  15. self.keep_prob = keep_prob
  16. def compute(self, input_vals, output_val, stream_handle=None, inference=False):
  17. if inference:
  18. if self.on_cpu:
  19. output_val[:] = input_vals[0].asnumpy()
  20. else:
  21. input_vals[0].copyto(output_val)
  22. else:
  23. if self.on_cpu:
  24. raise NotImplementedError
  25. else:
  26. dropout2d(input_vals[0], 1 - self.keep_prob,
  27. output_val, self.seed, stream_handle)
  28. def gradient(self, output_grad):
  29. return [dropout2d_gradient_op(output_grad, self.keep_prob, self, ctx=self.raw_ctx)]
  30. def infer_shape(self, input_shapes):
  31. return input_shapes[0]
  32. class Dropout2d_GradientOp(Op):
  33. def __init__(self, node_in, keep_prob, forward_node, ctx=None):
  34. super().__init__(Dropout2d_GradientOp, [node_in], ctx)
  35. self.forward_node = forward_node
  36. self.keep_prob = keep_prob
  37. def compute(self, input_vals, output_val, stream_handle=None):
  38. if self.on_cpu:
  39. raise NotImplementedError
  40. else:
  41. dropout2d_gradient(
  42. input_vals[0], 1 - self.keep_prob, output_val, self.forward_node.seed, stream_handle)
  43. def gradient(self, output_grad):
  44. raise NotImplementedError
  45. def infer_shape(self, input_shapes):
  46. return input_shapes[0]
  47. def dropout2d_op(node_in, keep_prob, ctx=None):
  48. """Drops elements of input variable randomly.
  49. Parameters:
  50. ----
  51. node_in : Node
  52. Input variable.
  53. keep_prob : float
  54. Probability of the results to be kept.
  55. Returns:
  56. ----
  57. A new Node instance created by Op.
  58. """
  59. return Dropout2dOp(node_in, keep_prob, ctx=ctx)
  60. def dropout2d_gradient_op(node_in, keep_prob, forward_node, ctx=None):
  61. """Gradient node of dropout2d operation.
  62. Parameters:
  63. ----
  64. node_in : Node
  65. Input variable.
  66. keep_prob : float
  67. Probability of the results to be kept.
  68. Returns:
  69. ----
  70. A new Node instance created by Op.
  71. """
  72. return Dropout2d_GradientOp(node_in, keep_prob, forward_node, ctx=ctx)