You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Dispatch.py 1.8 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. class DispatchOp(Op):
  4. def __init__(self, node, parts, duplicate=1):
  5. super().__init__(DispatchOp, [node], None)
  6. self.parts = parts
  7. self.duplicate = duplicate
  8. def compute(self, input_vals, output_val, stream_handle=None):
  9. assert False, "This Op should be replaced in preprocessing phase."
  10. def gradient(self, output_grad):
  11. return [dispatch_gradient(output_grad, self.inputs[0])]
  12. def infer_shape(self, input_shapes):
  13. assert False, "This Op should be replaced in preprocessing phase."
  14. class DispatchGradientOp(Op):
  15. def __init__(self, node, forward_input):
  16. super().__init__(DispatchGradientOp, [node, forward_input], None)
  17. def compute(self, input_vals, output_val, stream_handle=None):
  18. assert False, "This Op should be replaced in preprocessing phase."
  19. def gradient(self, output_grad):
  20. raise NotImplementedError
  21. def infer_shape(self, input_shapes):
  22. assert False, "This Op should be replaced in preprocessing phase."
  23. def dispatch(node, parts, duplicate=1):
  24. """Dispatch a node into several parts, so the nodes following up can use model parallel.
  25. Parameters:
  26. ----
  27. node : Node
  28. The input Node.
  29. parts: tuple
  30. Indicates number of partitions in each dimension.
  31. Returns:
  32. ----
  33. A new Node instance created by Op.
  34. """
  35. return DispatchOp(node, parts, duplicate)
  36. def dispatch_gradient(node, forward_input):
  37. """Gradient node for Dispatch.
  38. Parameters:
  39. ----
  40. node : Node
  41. The input Node.
  42. forward_input: Node
  43. The original input node in forward phase.
  44. Returns:
  45. ----
  46. A new Node instance created by Op.
  47. """
  48. return DispatchGradientOp(node, forward_input)