You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Node.py 6.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
from __future__ import absolute_import
import numpy as np
from .. import ndarray
from .. import stream
from ..context import get_current_context, DeviceGroup
  6. G_NODE_ID = 0
  7. class Op(object):
  8. """Node in a computation graph."""
  9. def __init__(self, op_type, inputs, ctx=None):
  10. """Constructor
  11. Instance variables
  12. ------------------
  13. self.inputs: the list of input nodes.
  14. self.const_attr: the add or multiply constant.
  15. e.g. self.const_attr=5 if this node is created by x+5.
  16. self.name: node name for debugging.
  17. """
  18. self.inputs = inputs
  19. self.raw_ctx = get_current_context() if ctx is None else DeviceGroup(ctx)
  20. self.ctx = ctx
  21. self.const_attr = None
  22. self.dtype = None
  23. self.inplace = False
  24. self.lazy_execution = False
  25. self.event = None
  26. self.op_type = op_type.__name__
  27. global G_NODE_ID
  28. self.id = G_NODE_ID
  29. G_NODE_ID = G_NODE_ID + 1
  30. self.name = self.op_type + str(self.id)
  31. self.desc = self.name + \
  32. '(' + ', '.join([inp.name for inp in inputs]) + ')'
  33. def __add__(self, other):
  34. """Adding two nodes return a new node."""
  35. from .AddElewise import add_op
  36. from .AddConst import addbyconst_op
  37. # here the operator does NOT specify context
  38. # please explicitly specify the context in gradients!!
  39. if isinstance(other, Op):
  40. new_node = add_op(self, other)
  41. else:
  42. # Add by a constant stores the constant in new node's const_attr
  43. # 'other' argument is a constant
  44. new_node = addbyconst_op(self, other)
  45. return new_node
  46. def __mul__(self, other):
  47. """Multiplying two nodes return a new node."""
  48. from .MultiplyElewise import mul_op
  49. from .MultiplyConst import mul_byconst_op
  50. if isinstance(other, Op):
  51. new_node = mul_op(self, other)
  52. else:
  53. # Mul by a constant stores the constant in new node's const_attr
  54. # 'other' argument is a constant
  55. new_node = mul_byconst_op(self, other)
  56. return new_node
  57. # Allow left-hand-side add and multiply.
  58. __radd__ = __add__
  59. __rmul__ = __mul__
  60. def __str__(self):
  61. """Allow print to display node name."""
  62. return self.name
  63. def compute(self, input_vals, output_val, stream_handle=None):
  64. """Given values of input nodes, compute the output value.
  65. Parameters
  66. ----------
  67. node: node that performs the compute.
  68. input_vals: values of input nodes.
  69. output_val: output value of the node, modified in-place.
  70. """
  71. raise NotImplementedError
  72. def gradient(self, output_grad):
  73. """Given output gradient, compute partial gradient to each input node.
  74. Parameters
  75. ----------
  76. node: node that performs the gradient.
  77. output_grad: output gradient summed from children nodes' contributions
  78. Returns
  79. -------
  80. A list of gradient contributions to each input node respectively.
  81. """
  82. raise NotImplementedError
  83. def infer_shape(self, input_shapes):
  84. """Given shapes of input nodes, compute shape of output node.
  85. Implementation note:
  86. It's simpler to treat shape of constants as (1,), so that constants can
  87. be stored as a numpy array too and you would need fewer special case
  88. handling.
  89. Parameters
  90. ----------
  91. node: node whose shape is being inferred.
  92. input_vals: shapes of input nodes.
  93. Returns
  94. -------
  95. A tuple representing the shape of output node.
  96. """
  97. raise NotImplementedError
  98. def add_transfer_op(self, src_node, dst_ctx, h2d_ops, d2h_ops):
  99. from .DataTransfer import datah2d_op, datad2h_op, datad2h_sparse_op
  100. def add_h2d(prev_node, cur_ctx):
  101. if prev_node not in h2d_ops:
  102. h2d_ops[prev_node] = datah2d_op(prev_node, cur_ctx)
  103. return h2d_ops[prev_node]
  104. def add_d2h(prev_node):
  105. from .EmbeddingLookUp import EmbeddingLookUp_Gradient
  106. if prev_node not in d2h_ops:
  107. if isinstance(prev_node, EmbeddingLookUp_Gradient):
  108. d2h_ops[prev_node] = datad2h_sparse_op(prev_node)
  109. else:
  110. d2h_ops[prev_node] = datad2h_op(prev_node)
  111. if prev_node.event is None:
  112. # here we should ensure the computation complete before d2h
  113. prev_node.event = stream.create_event_handle(prev_node.ctx)
  114. return d2h_ops[prev_node]
  115. src_ctx = src_node.ctx
  116. result = src_node
  117. if src_ctx != dst_ctx:
  118. if ndarray.is_gpu_ctx(dst_ctx):
  119. if ndarray.is_gpu_ctx(src_ctx):
  120. assert False, 'Please use NCCL to P2P communicate!'
  121. else:
  122. result = add_h2d(result, dst_ctx)
  123. else:
  124. result = add_d2h(result)
  125. return result
  126. def forward_hook(self, config):
  127. # disable inplace if not lazy execution
  128. # previously we use array reshape lazy callback to do this, which is deprecated (not efficient)
  129. if not self.lazy_execution:
  130. for node in self.inputs:
  131. node.inplace = False
  132. # insert data transfer op if needed
  133. input_ctxs = set([n.ctx for n in self.inputs])
  134. assert None not in input_ctxs, 'Inputs contexts should already be determined.'
  135. if self.ctx is None:
  136. self.ctx = config.context
  137. for i in range(len(self.inputs)):
  138. self.inputs[i] = self.add_transfer_op(
  139. self.inputs[i], self.ctx, config.h2d_ops, config.d2h_ops)
  140. self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
  141. self.on_cpu = not self.on_gpu
  142. if self in config.eval_node_list and self.on_gpu and self.event is None:
  143. self.event = stream.create_event_handle(self.ctx)
  144. def backward_hook(self, config):
  145. pass
  146. def deduce_states(self, input_states, input_duplicates):
  147. assert len(input_states) == len(self.inputs)
  148. assert len(input_states) == len(input_duplicates)
  149. if len(input_states) == 1:
  150. return input_states[0], input_duplicates[0]
  151. else:
  152. assert all([x is None or x == (1, 1) for x in input_states])
  153. return None, 1