from __future__ import absolute_import
import numpy as np
from .. import ndarray
from .. import stream
from ..context import get_current_context, DeviceGroup

G_NODE_ID = 0


class Op(object):
    """Node in a computation graph."""

    def __init__(self, op_type, inputs, ctx=None):
        """Constructor.

        Instance variables
        ------------------
        self.inputs: the list of input nodes.
        self.const_attr: the additive or multiplicative constant,
            e.g. self.const_attr = 5 if this node is created by x + 5.
        self.name: node name for debugging.
        """
        self.inputs = inputs
        self.raw_ctx = get_current_context() if ctx is None else DeviceGroup(ctx)
        self.ctx = ctx
        self.const_attr = None
        self.dtype = None
        self.inplace = False
        self.lazy_execution = False
        self.event = None
        self.op_type = op_type.__name__
        # Assign a globally unique id so node names are unambiguous.
        global G_NODE_ID
        self.id = G_NODE_ID
        G_NODE_ID = G_NODE_ID + 1
        self.name = self.op_type + str(self.id)
        self.desc = self.name + \
            '(' + ', '.join([inp.name for inp in inputs]) + ')'
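
    # For example, an elementwise-add node created as the third node in the
    # graph would get name "AddOp2" (op_type + id) and desc "AddOp2(a, b)"
    # for inputs named "a" and "b" (the op class name here is illustrative).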

    def __add__(self, other):
        """Adding two nodes returns a new node."""
        from .AddElewise import add_op
        from .AddConst import addbyconst_op

        # The operator overload does NOT specify a context;
        # please specify the context explicitly in gradients!
        if isinstance(other, Op):
            new_node = add_op(self, other)
        else:
            # Adding a constant stores the constant in the new node's
            # const_attr; the 'other' argument is a constant.
            new_node = addbyconst_op(self, other)
        return new_node

    def __mul__(self, other):
        """Multiplying two nodes returns a new node."""
        from .MultiplyElewise import mul_op
        from .MultiplyConst import mul_byconst_op

        if isinstance(other, Op):
            new_node = mul_op(self, other)
        else:
            # Multiplying by a constant stores the constant in the new
            # node's const_attr; the 'other' argument is a constant.
            new_node = mul_byconst_op(self, other)
        return new_node

    # Allow left-hand-side add and multiply.
    __radd__ = __add__
    __rmul__ = __mul__
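
    # A minimal usage sketch of the overloads above (assuming a placeholder
    # op exists elsewhere in this package; the names are illustrative):
    #
    #     x = placeholder_op(name='x')
    #     y = 3 * x + 5    # __rmul__ -> mul_byconst_op, __add__ -> addbyconst_op
    #     z = y * x        # both operands are Ops -> mul_op
    #
    # __radd__/__rmul__ make `5 + x` and `3 * x` dispatch to the same
    # constant-handling branches as `x + 5` and `x * 3`.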

    def __str__(self):
        """Allow print to display the node name."""
        return self.name

    def compute(self, input_vals, output_val, stream_handle=None):
        """Given values of input nodes, compute the output value.

        Parameters
        ----------
        input_vals: values of input nodes.
        output_val: output value of the node, modified in-place.
        stream_handle: optional stream on which to launch the computation.
        """
        raise NotImplementedError

    def gradient(self, output_grad):
        """Given the output gradient, compute a partial gradient for each input node.

        Parameters
        ----------
        output_grad: output gradient summed from the children nodes' contributions.

        Returns
        -------
        A list of gradient contributions to each input node respectively.
        """
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        """Given shapes of input nodes, compute the shape of the output node.

        Implementation note:
        It's simpler to treat the shape of a constant as (1,), so that
        constants can also be stored as numpy arrays and fewer special
        cases need handling.

        Parameters
        ----------
        input_shapes: shapes of input nodes.

        Returns
        -------
        A tuple representing the shape of the output node.
        """
        raise NotImplementedError
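
    # Sketch of how a concrete op fills in the three methods above: a
    # hypothetical elementwise-add op (illustrative only, not this
    # package's actual AddElewise implementation):
    #
    #     class AddOp(Op):
    #         def compute(self, input_vals, output_val, stream_handle=None):
    #             # write the result into output_val in place
    #             output_val[:] = input_vals[0].asnumpy() + input_vals[1].asnumpy()
    #
    #         def gradient(self, output_grad):
    #             # d(a + b)/da = d(a + b)/db = 1
    #             return [output_grad, output_grad]
    #
    #         def infer_shape(self, input_shapes):
    #             assert input_shapes[0] == input_shapes[1]
    #             return input_shapes[0]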

    def add_transfer_op(self, src_node, dst_ctx, h2d_ops, d2h_ops):
        from .DataTransfer import datah2d_op, datad2h_op, datad2h_sparse_op

        def add_h2d(prev_node, cur_ctx):
            if prev_node not in h2d_ops:
                h2d_ops[prev_node] = datah2d_op(prev_node, cur_ctx)
            return h2d_ops[prev_node]

        def add_d2h(prev_node):
            from .EmbeddingLookUp import EmbeddingLookUp_Gradient
            if prev_node not in d2h_ops:
                if isinstance(prev_node, EmbeddingLookUp_Gradient):
                    d2h_ops[prev_node] = datad2h_sparse_op(prev_node)
                else:
                    d2h_ops[prev_node] = datad2h_op(prev_node)
                if prev_node.event is None:
                    # ensure the computation has completed before the
                    # device-to-host copy starts
                    prev_node.event = stream.create_event_handle(
                        prev_node.ctx)
            return d2h_ops[prev_node]

        src_ctx = src_node.ctx
        result = src_node
        if src_ctx != dst_ctx:
            if ndarray.is_gpu_ctx(dst_ctx):
                if ndarray.is_gpu_ctx(src_ctx):
                    assert False, 'Please use NCCL for P2P communication!'
                else:
                    result = add_h2d(result, dst_ctx)
            else:
                result = add_d2h(result)
        return result
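
    # Transfer decisions made above, keyed by (source ctx -> destination ctx):
    #
    #     CPU -> GPU  : insert a host-to-device op (cached in h2d_ops)
    #     GPU -> CPU  : insert a device-to-host op (cached in d2h_ops;
    #                   a sparse variant for EmbeddingLookUp_Gradient)
    #     GPU -> GPU  : unsupported here; use NCCL peer-to-peer instead
    #     same ctx    : no transfer; the node is returned unchanged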

    def forward_hook(self, config):
        # Disable inplace computation on inputs unless in lazy execution
        # mode. Previously this was done via a lazy array-reshape callback,
        # which is deprecated (inefficient).
        if not self.lazy_execution:
            for node in self.inputs:
                node.inplace = False

        # Insert data transfer ops where input contexts differ from ours.
        input_ctxs = set([n.ctx for n in self.inputs])
        assert None not in input_ctxs, 'Input contexts should already be determined.'
        if self.ctx is None:
            self.ctx = config.context
        for i in range(len(self.inputs)):
            self.inputs[i] = self.add_transfer_op(
                self.inputs[i], self.ctx, config.h2d_ops, config.d2h_ops)
        self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
        self.on_cpu = not self.on_gpu
        # Evaluation nodes on GPU need an event so the executor can
        # synchronize on their completion.
        if self in config.eval_node_list and self.on_gpu and self.event is None:
            self.event = stream.create_event_handle(self.ctx)
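
    # After this hook runs, every entry of self.inputs lives on a context
    # compatible with self.ctx (transfer ops spliced in as needed), so
    # compute() can assume device-local inputs.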

    def backward_hook(self, config):
        pass

    def deduce_states(self, input_states, input_duplicates):
        assert len(input_states) == len(self.inputs)
        assert len(input_states) == len(input_duplicates)
        if len(input_states) == 1:
            # Single input: propagate its state and duplicate unchanged.
            return input_states[0], input_duplicates[0]
        else:
            # Multiple inputs: this default implementation only supports
            # the trivial (unpartitioned) case.
            assert all([x is None or x == (1, 1) for x in input_states])
            return None, 1
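
    # Illustrative values for the default deduction above: with one input,
    # deduce_states([(2, 1)], [2]) returns ((2, 1), 2); with several inputs,
    # e.g. input_states = [None, (1, 1)], it falls back to (None, 1).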