You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Relu.py 2.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from .._base import DNNL_LIB
  5. from ..cpu_links import relu as cpu_relu
  6. from ..cpu_links import relu_gradient as cpu_relu_gradient
  7. from ..gpu_links import relu
  8. from ..gpu_links import relu_gradient
  9. class ReluOp(Op):
  10. def __init__(self, node_A, ctx=None):
  11. super().__init__(ReluOp, [node_A], ctx)
  12. def compute(self, input_vals, output_val, stream_handle=None):
  13. if self.on_cpu:
  14. if DNNL_LIB['DnnlRelu']:
  15. cpu_relu(input_vals[0], output_val)
  16. else:
  17. output_val[:] = np.maximum(input_vals[0].asnumpy(), 0)
  18. else:
  19. relu(input_vals[0], output_val, stream_handle)
  20. def gradient(self, output_grad):
  21. return [relu_gradient_op(self.inputs[0], output_grad, ctx=self.raw_ctx)]
  22. def infer_shape(self, input_shapes):
  23. assert len(input_shapes) == 1
  24. return input_shapes[0]
  25. class ReluGradientOp(Op):
  26. def __init__(self, node_A, node_B, ctx=None):
  27. super().__init__(ReluGradientOp, [node_A, node_B], ctx)
  28. def compute(self, input_vals, output_val, stream_handle=None):
  29. if self.on_cpu:
  30. if DNNL_LIB['DnnlRelu_Gradient']:
  31. cpu_relu_gradient(input_vals[0], input_vals[1], output_val)
  32. # heaviside function, 0.5 at x=0
  33. else:
  34. output_val[:] = (np.sign(input_vals[0].asnumpy()) +
  35. 1) * 0.5 * input_vals[1].asnumpy()
  36. else:
  37. relu_gradient(input_vals[0], input_vals[1],
  38. output_val, stream_handle)
  39. def gradient(self, output_grad):
  40. raise NotImplementedError
  41. def infer_shape(self, input_shapes):
  42. assert len(input_shapes) == 2
  43. return input_shapes[0]
  44. def relu_op(node, ctx=None):
  45. """Rectified Linear Unit.
  46. Parameters:
  47. ----
  48. node : Node
  49. Input variable.
  50. Returns:
  51. ----
  52. A new Node instance created by Op.
  53. """
  54. return ReluOp(node, ctx=ctx)
  55. def relu_gradient_op(node_A, node_B, ctx=None):
  56. """Computes the gradient of the ReLU function.
  57. Parameters:
  58. ----
  59. node_A : Node
  60. Relu input.
  61. node_B : Node
  62. Previous gradient node.
  63. Returns:
  64. ----
  65. A new Node instance created by Op.
  66. """
  67. return ReluGradientOp(node_A, node_B, ctx=ctx)