from __future__ import absolute_import

from .Node import Op

# Shared failure message for the placeholder ops defined in this module.
_MUST_BE_REPLACED_MSG = "This Op should be replaced in preprocessing phase."


class DispatchOp(Op):
    """Placeholder op that marks a node to be split into several parts.

    The op carries only the partitioning information; calling ``compute``
    or ``infer_shape`` on it always fails with the message that it should
    be replaced in the preprocessing phase.
    """

    def __init__(self, node, parts, duplicate=1):
        """Create the placeholder op.

        Parameters:
        ----
        node : Node
            The input Node.
        parts : tuple
            Indicates number of partitions in each dimension.
        duplicate : int
            Stored as-is on the op (default 1).
            NOTE(review): presumably a replication count -- confirm with
            the code that consumes this attribute.
        """
        # The third argument of Op.__init__ is passed as None, as before;
        # its meaning is defined by the base class (not visible here).
        super().__init__(DispatchOp, [node], None)
        self.parts = parts
        self.duplicate = duplicate

    def compute(self, input_vals, output_val, stream_handle=None):
        """Always fail: this op is not meant to be executed.

        Raises:
        ----
        AssertionError
            Unconditionally.
        """
        # Raise explicitly instead of ``assert False`` so the guard is not
        # stripped when Python runs with -O.
        raise AssertionError(_MUST_BE_REPLACED_MSG)

    def gradient(self, output_grad):
        """Build the gradient node for this op.

        Parameters:
        ----
        output_grad : Node
            The gradient flowing into this op's output.

        Returns:
        ----
        A single-element list holding the node created by
        ``dispatch_gradient`` from ``output_grad`` and this op's first input.
        """
        return [dispatch_gradient(output_grad, self.inputs[0])]

    def infer_shape(self, input_shapes):
        """Always fail: this op is not meant to take part in shape inference.

        Raises:
        ----
        AssertionError
            Unconditionally.
        """
        # Explicit raise keeps the guard active under -O.
        raise AssertionError(_MUST_BE_REPLACED_MSG)


class DispatchGradientOp(Op):
    """Placeholder op for the gradient of ``DispatchOp``.

    Like ``DispatchOp``, calling ``compute`` or ``infer_shape`` on it
    always fails with the message that it should be replaced in the
    preprocessing phase.
    """

    def __init__(self, node, forward_input):
        """Create the placeholder gradient op.

        Parameters:
        ----
        node : Node
            The input Node.
        forward_input : Node
            The original input node in forward phase; passed to the base
            class as the second input.
        """
        super().__init__(DispatchGradientOp, [node, forward_input], None)

    def compute(self, input_vals, output_val, stream_handle=None):
        """Always fail: this op is not meant to be executed.

        Raises:
        ----
        AssertionError
            Unconditionally.
        """
        # Explicit raise keeps the guard active under -O.
        raise AssertionError(_MUST_BE_REPLACED_MSG)

    def gradient(self, output_grad):
        """Not supported for this op.

        Raises:
        ----
        NotImplementedError
            Unconditionally.
        """
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        """Always fail: this op is not meant to take part in shape inference.

        Raises:
        ----
        AssertionError
            Unconditionally.
        """
        # Explicit raise keeps the guard active under -O.
        raise AssertionError(_MUST_BE_REPLACED_MSG)


def dispatch(node, parts, duplicate=1):
    """Dispatch a node into several parts, so the nodes following up can use
    model parallel.

    Parameters:
    ----
    node : Node
        The input Node.
    parts : tuple
        Indicates number of partitions in each dimension.
    duplicate : int
        Forwarded unchanged to ``DispatchOp`` (default 1).
        NOTE(review): presumably a replication count -- confirm.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return DispatchOp(node, parts, duplicate)


def dispatch_gradient(node, forward_input):
    """Gradient node for Dispatch.

    Parameters:
    ----
    node : Node
        The input Node.
    forward_input : Node
        The original input node in forward phase.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return DispatchGradientOp(node, forward_input)