You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

OneHot.py 1.2 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from ..gpu_links import one_hot
  5. class OneHotOp(Op):
  6. def __init__(self, node_A, num_classes, ctx=None):
  7. super().__init__(OneHotOp, [node_A], ctx)
  8. self.num_classes = num_classes
  9. def compute(self, input_vals, output_val, stream_handle=None):
  10. if self.on_cpu:
  11. inputs = input_vals[0].asnumpy().astype(np.int)
  12. res = np.eye(self.num_classes)[inputs.reshape(-1)]
  13. output_val[:] = res.reshape(
  14. list(inputs.shape) + [self.num_classes]).astype(np.float32)
  15. else:
  16. one_hot(input_vals[0], output_val, stream_handle)
  17. def gradient(self, output_grad):
  18. return [None]
  19. def infer_shape(self, input_shapes):
  20. assert len(input_shapes) == 1
  21. return tuple(list(input_shapes[0]) + [self.num_classes])
  22. def one_hot_op(node, num_classes, ctx=None):
  23. """Creates a node that represents one hot.
  24. Parameters:
  25. ----
  26. node : Node
  27. The input Node.
  28. num_classes: int
  29. Number of classes.
  30. Returns:
  31. ----
  32. A new Node instance created by Op.
  33. """
  34. return OneHotOp(node, num_classes, ctx=ctx)