|
- from __future__ import absolute_import
- import numpy as np
- from .Node import Op
- from ..gpu_links import one_hot
-
-
class OneHotOp(Op):
    """Operator that expands integer class indices into one-hot vectors.

    The output has the input's shape with one extra trailing axis of
    length ``num_classes``.
    """

    def __init__(self, node_A, num_classes, ctx=None):
        """Build the op on top of ``node_A``.

        Parameters
        ----------
        node_A : Node
            Input node holding the class indices.
        num_classes : int
            Length of the trailing one-hot axis.
        ctx : optional
            Forwarded unchanged to the ``Op`` base constructor.
        """
        super().__init__(OneHotOp, [node_A], ctx)
        self.num_classes = num_classes

    def compute(self, input_vals, output_val, stream_handle=None):
        """Write the one-hot encoding of ``input_vals[0]`` into ``output_val``.

        On the CPU path the encoding is computed with NumPy; otherwise the
        work is delegated to the ``one_hot`` GPU link.

        Raises
        ------
        IndexError
            (CPU path) if an index is outside ``[-num_classes, num_classes)``.
            Negative indices in that range wrap around, as with ordinary
            NumPy indexing.
        """
        if self.on_cpu:
            # ``np.int`` was removed in NumPy 1.24; use an explicit
            # fixed-width integer type for the index cast instead.
            indices = input_vals[0].asnumpy().astype(np.int64)
            flat = indices.reshape(-1)
            # Fill a (n, num_classes) float32 buffer directly rather than
            # slicing rows out of np.eye(num_classes), which would allocate
            # a num_classes x num_classes float64 matrix first.
            encoded = np.zeros((flat.size, self.num_classes), dtype=np.float32)
            encoded[np.arange(flat.size), flat] = 1.0
            output_val[:] = encoded.reshape(
                list(indices.shape) + [self.num_classes])
        else:
            one_hot(input_vals[0], output_val, stream_handle)

    def gradient(self, output_grad):
        """Return ``[None]``: no gradient is propagated to the index input."""
        return [None]

    def infer_shape(self, input_shapes):
        """Return the input shape with ``num_classes`` appended as a new axis.

        ``input_shapes`` must contain exactly one shape.
        """
        assert len(input_shapes) == 1
        return tuple(list(input_shapes[0]) + [self.num_classes])
-
-
def one_hot_op(node, num_classes, ctx=None):
    """Create a graph node that one-hot encodes ``node``.

    Parameters
    ----------
    node : Node
        The input node.
    num_classes : int
        Number of classes.
    ctx : optional
        Passed through to the ``OneHotOp`` constructor.

    Returns
    -------
    A new ``OneHotOp`` instance.
    """
    one_hot_node = OneHotOp(node, num_classes, ctx=ctx)
    return one_hot_node
|