|
- from __future__ import absolute_import
- import numpy as np
- from .Node import Op
- from ..gpu_links import where
-
-
class WhereOp(Op):
    """Op node that selects between two inputs according to a condition input."""

    def __init__(self, cond, node_A, node_B, ctx=None):
        """Build the node with inputs ordered as (cond, node_A, node_B)."""
        operands = [cond, node_A, node_B]
        super().__init__(WhereOp, operands, ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        """Fill ``output_val`` with the selection result.

        ``input_vals`` holds the condition, the first branch and the second
        branch, in that order. When ``self.on_cpu`` is false the work is
        handed to the ``where`` routine imported from ``gpu_links``;
        otherwise the three inputs are converted with ``asnumpy()`` and
        combined with ``np.where``.
        """
        if not self.on_cpu:
            where(input_vals[0], input_vals[1], input_vals[2],
                  output_val, stream_handle)
            return

        mask = input_vals[0].asnumpy()
        if_true = input_vals[1].asnumpy()
        if_false = input_vals[2].asnumpy()
        output_val[:] = np.where(mask, if_true, if_false)

    def gradient(self, output_grad):
        """Return gradient nodes for (cond, node_A, node_B).

        The condition gets ``None``. Each branch gets ``output_grad`` routed
        through another where-node, with a zeros-like node filling the
        positions that the branch did not contribute to.
        """
        from .ZerosLike import zeroslike_op

        cond = self.inputs[0]
        zero_fill = zeroslike_op(cond, ctx=self.raw_ctx)
        grad_for_a = where_op(cond, output_grad, zero_fill, ctx=self.raw_ctx)
        grad_for_b = where_op(cond, zero_fill, output_grad, ctx=self.raw_ctx)
        return [None, grad_for_a, grad_for_b]

    def infer_shape(self, input_shapes):
        """Return the output shape, which is the shape of the condition input.

        Asserts that exactly three shapes are given and that all three are
        equal when compared as tuples (no broadcasting).
        """
        assert len(input_shapes) == 3
        cond_shape = input_shapes[0]
        assert tuple(cond_shape) == tuple(
            input_shapes[1]) == tuple(input_shapes[2])
        return cond_shape
-
-
def where_op(cond, node_A, node_B, ctx=None):
    """Build a node that performs an np.where-style selection.

    Parameters:
    ----
    cond : Node holding the condition array.
    node_A : Node whose values are used where the condition holds.
    node_B : Node whose values are used where the condition does not hold.
    ctx : Optional context forwarded unchanged to the op constructor.

    Returns:
    ----
    The newly constructed WhereOp node.

    """
    node = WhereOp(cond, node_A, node_B, ctx=ctx)
    return node
|