You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Softmax.py 2.7 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from .._base import DNNL_LIB
  5. from ..cpu_links import softmax as cpu_softmax
  6. from ..gpu_links import CuDNN_softmax
  7. from ..gpu_links import CuDNN_softmax_gradient
  8. def softmax_func(y):
  9. """Numerically stable softmax."""
  10. b = y - np.max(y, axis=-1, keepdims=True)
  11. expb = np.exp(b)
  12. softmax = expb / np.sum(expb, axis=-1, keepdims=True)
  13. return softmax
  14. def softmax_gradient_func(y, dy):
  15. dx = y * (dy - (dy * y).sum(axis=-1, keepdims=True))
  16. return dx
  17. class SoftmaxOp(Op):
  18. def __init__(self, node_A, ctx=None):
  19. super().__init__(SoftmaxOp, [node_A], ctx)
  20. def compute(self, input_vals, output_val, stream_handle=None):
  21. if self.on_cpu:
  22. if DNNL_LIB['DnnlSoftmax']:
  23. cpu_softmax(input_vals[0], output_val)
  24. else:
  25. output_val[:] = softmax_func(input_vals[0].asnumpy())
  26. else:
  27. CuDNN_softmax(input_vals[0], output_val, stream_handle)
  28. def gradient(self, output_grad):
  29. # Do not directly use SoftmaxOp, use SoftmaxCrossEntropyOp instead.
  30. # Not allowing taking 2nd derivative of SoftmaxCrossEntropyOp.
  31. return [softmax_gradient_op(self, output_grad, ctx=self.raw_ctx)]
  32. def infer_shape(self, input_shapes):
  33. assert len(input_shapes) == 1
  34. return input_shapes[0]
  35. class SoftmaxGradientOp(Op):
  36. def __init__(self, node_y, grad, ctx=None):
  37. super().__init__(SoftmaxGradientOp, [node_y, grad], ctx)
  38. def compute(self, input_vals, output_val, stream_handle=None):
  39. if self.on_cpu:
  40. output_val[:] = softmax_gradient_func(
  41. input_vals[0].asnumpy(), input_vals[1].asnumpy())
  42. else:
  43. CuDNN_softmax_gradient(
  44. input_vals[0], input_vals[1], output_val, stream_handle)
  45. def gradient(self, output_grad):
  46. raise NotImplementedError
  47. def infer_shape(self, input_shapes):
  48. assert len(input_shapes) == 2
  49. return input_shapes[0]
  50. def softmax_op(node, ctx=None):
  51. """ This function computes its softmax along an axis.
  52. Parameters:
  53. ----
  54. node : Node
  55. Input variable.
  56. Returns:
  57. ----
  58. A new Node instance created by Op.
  59. """
  60. return SoftmaxOp(node, ctx=ctx)
  61. def softmax_gradient_op(node_y, grad, ctx=None):
  62. """ This function computes softmax gradient.
  63. Parameters:
  64. ----
  65. node_y: Node
  66. Output variable of forward softmax.
  67. grad: Node
  68. Gradient variable, dy.
  69. Returns:
  70. ----
  71. A new Node instance created by Op.
  72. """
  73. return SoftmaxGradientOp(node_y, grad, ctx=ctx)