You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BinaryCrossEntropy.py 2.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from ..gpu_links import binary_cross_entropy
  5. from ..gpu_links import binary_cross_entropy_gradient
  6. class BinaryCrossEntropyOp(Op):
  7. def __init__(self, prediction, label, ctx=None):
  8. super().__init__(BinaryCrossEntropyOp, [prediction, label], ctx)
  9. def compute(self, input_vals, output_val, stream_handle=None):
  10. if self.on_cpu:
  11. y = input_vals[0].asnumpy()
  12. y_ = input_vals[1].asnumpy()
  13. output_val[:] = -y_ * np.log(y) - (1 - y_) * np.log(1 - y)
  14. else:
  15. binary_cross_entropy(
  16. input_vals[0], input_vals[1], output_val, stream_handle)
  17. def gradient(self, output_grad):
  18. grad_A = binarycrossentropy_gradient_op(
  19. self.inputs[0], self.inputs[1], output_grad, ctx=self.raw_ctx)
  20. grad_B = None
  21. return [grad_A, grad_B]
  22. def infer_shape(self, input_shapes):
  23. assert len(input_shapes) == 2
  24. assert len(input_shapes[0]) >= 2
  25. return input_shapes[0]
  26. class BinaryCrossEntropyGradientOp(Op):
  27. def __init__(self, prediction, label, output_grad_node, ctx=None):
  28. super().__init__(BinaryCrossEntropyGradientOp, [
  29. prediction, label, output_grad_node], ctx)
  30. def compute(self, input_vals, output_val, stream_handle=None):
  31. if self.on_cpu:
  32. y = input_vals[0].asnumpy()
  33. y_ = input_vals[1].asnumpy()
  34. output_grad = input_vals[2].asnumpy()
  35. output_val[:] = (- y_/y + (1 - y_)/(1-y))*output_grad
  36. else:
  37. binary_cross_entropy_gradient(
  38. input_vals[0], input_vals[1], input_vals[2], output_val, stream_handle)
  39. def gradient(self, output_grad):
  40. raise NotImplementedError
  41. def infer_shape(self, input_shapes):
  42. assert len(input_shapes) == 3
  43. return input_shapes[0]
  44. def binarycrossentropy_op(node_A, node_B, ctx=None):
  45. """Computes cross entropy loss for pre-softmax activations.
  46. Parameters:
  47. ----
  48. node_A : Node
  49. Predicted probability.
  50. node_B : Node
  51. Labels.
  52. Returns:
  53. ----
  54. A new Node instance created by Op.
  55. """
  56. return BinaryCrossEntropyOp(node_A, node_B, ctx=ctx)
  57. def binarycrossentropy_gradient_op(node_A, node_B, node_C, ctx=None):
  58. return BinaryCrossEntropyGradientOp(node_A, node_B, node_C, ctx=ctx)