from .ndarray import cpu, gpu, rcpu, rgpu, DLContext, is_gpu_ctx
import contextlib
import re


class DeviceGroup(object):
    def __init__(self, ctxs):
        self._contexts = self.parse_contexts(ctxs)
        self.get_servers_n_workers()

    @classmethod
    def parse_contexts(cls, ctxs):
        if isinstance(ctxs, DeviceGroup):
            return ctxs
        if isinstance(ctxs, str):
            # split on semicolons, commas and whitespace, dropping empty
            # tokens so that strings like 'cpu:0, gpu:0' parse cleanly
            ctxs = [c for c in re.split(r'[;,\s]+', ctxs.lower()) if c]
        if not isinstance(ctxs, list):
            ctxs = [ctxs]
        new_ctxs = []
        for c in ctxs:
            if isinstance(c, tuple):
                c = tuple(cc for cc in map(cls.str2ctx, c) if cc is not None)
            else:
                c = cls.str2ctx(c)
            if c is not None:
                new_ctxs.append(c)
        return new_ctxs

    @classmethod
    def str2ctx(cls, c):
        if isinstance(c, str):
            c = c.lower().split(':')
            assert c[-2] in ('cpu', 'gpu'), 'Context invalid: %s' % c
            hostname = 'localhost' if len(c) == 2 else c[0]
            idx = int(c[-1])
            c = rcpu(hostname, idx) if c[-2] == 'cpu' else rgpu(hostname, idx)
        assert isinstance(c, DLContext), 'Context invalid: %s' % c
        return c
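
    # Accepted context strings (a sketch of the format): '[hostname:]cpu:i'
    # or '[hostname:]gpu:i', with the hostname defaulting to 'localhost';
    # 'worker1' below is illustrative:
    #   DeviceGroup.str2ctx('gpu:0')          -> rgpu('localhost', 0)
    #   DeviceGroup.str2ctx('worker1:cpu:0')  -> rcpu('worker1', 0)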
-
- def index(self, ctx):
- return self._contexts.index(ctx)
-
- def __getitem__(self, key):
- return self._contexts[key]
-
- def __iter__(self):
- return iter(self._contexts)
-
- def __len__(self):
- return len(self._contexts)
-
- @property
- def worker_num(self):
- return len(self._workers)
-
- @property
- def server_num(self):
- return len(self._servers)
-
- @property
- def workers(self):
- return self._workers
-
- @property
- def servers(self):
- return self._servers
-
- def get_servers_n_workers(self):
- self._workers = []
- self._servers = []
- for ctx in self._contexts:
- if isinstance(ctx, tuple) or is_gpu_ctx(ctx):
- self._workers.append(ctx)
- else:
- self._servers.append(ctx)
-

    def __repr__(self):
        result = 'DeviceGroup('
        for c in self._contexts:
            result += ('(' + ', '.join(str(cc) for cc in c) +
                       '), ') if isinstance(c, tuple) else '%s, ' % c
        result += ')'
        return result

    def __hash__(self):
        if not hasattr(self, 'hash'):
            self.hash = hash(
                tuple(sorted(self._contexts, key=lambda x: x.device_id)))
        return self.hash

    def __eq__(self, other):
        # compare via the cached hash, but only against other DeviceGroups
        return isinstance(other, DeviceGroup) and hash(self) == hash(other)

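
# A minimal usage sketch (device strings are illustrative): GPU contexts and
# tuples of contexts (model parallel) count as workers, CPU contexts as
# servers.
#
#   group = DeviceGroup('cpu:0,gpu:0,gpu:1')
#   group.server_num  # 1 -> the cpu context
#   group.worker_num  # 2 -> the two gpu contexts
#   DeviceGroup([('host1:gpu:0', 'host1:gpu:1')])  # one model-parallel worker
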
class ContextStack(object):
    def __init__(self):
        self._stack = []

    def peek(self):
        return self._stack[-1] if self._stack else None

    def push(self, ctx):
        self._stack.append(ctx)

    def pop(self):
        return self._stack.pop()


_default_ctx_stack = ContextStack()


def get_current_context():
    return _default_ctx_stack.peek()


@contextlib.contextmanager
def context(ctx):
    # construct and push before entering the try block so that a failure in
    # DeviceGroup construction cannot pop a frame that was never pushed
    ctx = DeviceGroup(ctx)
    _default_ctx_stack.push(ctx)
    try:
        yield ctx
    finally:
        _default_ctx_stack.pop()

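
# A minimal usage sketch: the DeviceGroup is pushed on entry and popped on
# exit, so nested scopes restore the enclosing context.
#
#   with context('gpu:0'):
#       outer = get_current_context()
#       with context('host1:gpu:0;host1:gpu:1'):
#           ...  # the inner DeviceGroup is current here
#       assert get_current_context() is outer
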
def check_worker(ctx):
    # a context counts as a worker if it is a GPU context or a tuple of
    # contexts (which means model parallel)
    return isinstance(ctx, tuple) or is_gpu_ctx(ctx)


def get_launch_config_by_traverse_nodes(node_list, default_ctx):
    node_strategy = dict()
    devices = set()
    for ctx in default_ctx:
        if isinstance(ctx, tuple):
            devices.update(ctx)
        else:
            devices.add(ctx)
    launchPS = default_ctx.server_num > 0
    launchMPI = (not launchPS) and default_ctx.worker_num > 1
    nrank = default_ctx.worker_num
    for node in node_list:
        traverse_dfs(node, node_strategy, devices, nrank)
    launchPS = launchPS or any(x == 'PS' for x in node_strategy.values())
    launchMPI = launchMPI or any(
        x == 'AllReduce' for x in node_strategy.values())
    return launchMPI, launchPS, node_strategy, devices

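
# A usage sketch (assuming `loss` is the root node of a built graph): the two
# flags tell the runtime which communication backends to initialize.
#
#   launch_mpi, launch_ps, strategy, devs = \
#       get_launch_config_by_traverse_nodes([loss], default_ctx)
#   # launch_ps: servers present, or some node uses the 'PS' strategy
#   # launch_mpi: several workers and no servers, or some node uses 'AllReduce'
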

def traverse_dfs(node, node_strategy, devices, nrank):
    if node in node_strategy:
        return
    strategy = None
    if node.raw_ctx is not None and node.raw_ctx.server_num > 0 and node.raw_ctx.worker_num > 0:
        strategy = 'PS'
    elif node.raw_ctx is not None and node.raw_ctx.worker_num > 1:
        strategy = 'AllReduce'
    node_strategy[node] = strategy
    if node.raw_ctx is not None:
        # raw_ctx may be None, as the local_nrank computation below assumes
        for ctx in node.raw_ctx:
            if isinstance(ctx, tuple):
                devices.update(ctx)
            else:
                devices.add(ctx)
    local_nrank = nrank if node.raw_ctx is None else node.raw_ctx.worker_num
    assert local_nrank in (
        0, nrank), 'Number of workers inconsistent: (%d, %d).' % (local_nrank, nrank)
    for n in node.inputs:
        traverse_dfs(n, node_strategy, devices, nrank)


def assign_context_by_traverse_nodes(node_list, ctx, mpi_comm, p2p_stream):
    from .dataloader import DataloaderOp
    from .optimizer import OptimizerOp
    from .gpu_ops.PipelineSend import pipeline_send_op
    from .gpu_ops.PipelineReceive import pipeline_receive_op
    from .gpu_ops.Variable import PlaceholderOp
    from .gpu_ops.Dispatch import DispatchOp, DispatchGradientOp
    from .gpu_ops.Concat import concat_op
    from .gpu_ops.Split import split_op
    from .gpu_ops.AddElewise import add_op

    def receive_model_parallel(prev_input, node):
        # assert dp_index_map[prev_input] < 0 and dp_index_map[node] >= 0
        dev_pos = dp_index_map[node]
        if isinstance(node.raw_ctx.workers[dev_pos], tuple):
            # receive from a node on one device that dispatches to many; the
            # receiving node MUST have an mp_index, and the split is handled
            # on the sending side
            assert mp_index_map[node] >= 0, 'Now only support 1 to N.'
            hostname = prev_input.raw_ctx.workers[dev_pos].hostname
            target_id = prev_input.raw_ctx.workers[dev_pos].device_id
            if prev_input not in recv_src:
                recv_src[prev_input] = pipeline_receive_op(mpi_comm.getRankFromDevice(
                    hostname, target_id), mpi_comm, stream=p2p_stream, ctx=ctx)
            return recv_src[prev_input]
        else:
            # receive from a node on multiple devices; the receiving node
            # MUST NOT have an mp_index, and the combination is handled here
            target = node_tar_states_map[prev_input]
            assert mp_index_map[node] < 0 and (target is None or all(
                ts == 1 for ts in target)), 'Now only support N to 1.'
            if prev_input not in recv_src:
                device_index = -1

                def make_comb(devices, cur_states, depth):
                    # receive one partition per device at the innermost level
                    # and concatenate partitions along each split dimension
                    if depth == len(cur_states):
                        nonlocal device_index
                        device_index += 1
                        return pipeline_receive_op(mpi_comm.getRankFromDevice(
                            devices[device_index].hostname,
                            devices[device_index].device_id),
                            mpi_comm, stream=p2p_stream, ctx=ctx)
                    else:
                        result = make_comb(devices, cur_states, depth + 1)
                        for _ in range(1, cur_states[depth]):
                            result = concat_op(result, make_comb(
                                devices, cur_states, depth + 1), axis=depth, ctx=ctx)
                        return result
                res = make_comb(
                    prev_input.raw_ctx.workers[dev_pos], node_cur_states_map[prev_input], 0)
                for _ in range(1, node_cur_duplicate_map.get(prev_input, 1)):
                    res = add_op(res, make_comb(
                        prev_input.raw_ctx.workers[dev_pos], node_cur_states_map[prev_input], 0), ctx=ctx)
                assert device_index + 1 == len(prev_input.raw_ctx.workers[dev_pos])
                recv_src[prev_input] = res
            return recv_src[prev_input]

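    # Worked example for make_comb above (a sketch): with
    # node_cur_states_map[prev_input] == [2, 2] and four source devices
    # d0..d3 in row-major order, the tensor is rebuilt as
    #   concat(concat(d0, d1, axis=1), concat(d2, d3, axis=1), axis=0)
    # and any duplicates are then summed together via add_op.
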
    def send_model_parallel(prev_input, node):
        # assert dp_index_map[prev_input] >= 0 and dp_index_map[node] < 0
        dev_pos = dp_index_map[prev_input]
        if not isinstance(prev_input.raw_ctx.workers[dev_pos], tuple):
            # send from a node on one device to many devices; the sending
            # node MUST NOT have an mp_index, and the split is handled here
            assert mp_index_map[prev_input] < 0, 'Now only support 1 to N.'
            device_index = 0

            def make_split(devices, target_states, cur_states, depth):
                # enumerate all target partitions; at the innermost level,
                # split out the partition for the current device and send it
                if len(target_states) == depth:
                    nonlocal device_index
                    hostname = devices[device_index].hostname
                    target_id = devices[device_index].device_id
                    device_index += 1
                    # key by hostname as well, since device ids are only
                    # unique within a host
                    key = (prev_input, hostname, target_id)
                    if key not in send_dst:
                        cur_node = prev_input if all(x == 1 for x in target_states) else split_op(
                            prev_input, list(range(len(target_states))), list(cur_states), list(target_states), ctx=ctx)
                        target_rank = mpi_comm.getRankFromDevice(
                            hostname, target_id)
                        send_dst[key] = pipeline_send_op(
                            cur_node, target_rank, mpi_comm, stream=p2p_stream, ctx=ctx)
                        my_eval_nodes.append(send_dst[key])
                else:
                    for ts in range(target_states[depth]):
                        cur_states[depth] = ts
                        make_split(devices, target_states,
                                   cur_states, depth + 1)
            for _ in range(node_tar_duplicate_map.get(prev_input, 1)):
                cur_states = [0 for _ in range(
                    len(node_tar_states_map[prev_input]))]
                make_split(
                    node.raw_ctx.workers[dev_pos], node_tar_states_map[prev_input], cur_states, 0)
            assert device_index == len(node.raw_ctx.workers[dev_pos])
        else:
            # send from a node on multiple devices to one device; the sending
            # node MUST have an mp_index, and the combination is handled on
            # the receiving side
            target = node_tar_states_map[prev_input]
            assert mp_index_map[prev_input] >= 0 and (target is None or all(
                ts == 1 for ts in target)), 'Now only support N to 1.'
            hostname = node.raw_ctx.workers[dev_pos].hostname
            target_id = node.raw_ctx.workers[dev_pos].device_id
            key = (prev_input, hostname, target_id)
            if key not in send_dst:
                send_dst[key] = pipeline_send_op(prev_input, mpi_comm.getRankFromDevice(
                    hostname, target_id), mpi_comm, stream=p2p_stream, ctx=ctx)
                my_eval_nodes.append(send_dst[key])

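    # Worked example for make_split above (a sketch): with
    # node_tar_states_map[prev_input] == [1, 2], the tensor is halved along
    # dimension 1; cur_states enumerates [0, 0] and [0, 1], and each split_op
    # output is sent to the matching target device in order.
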
    def assign_ctx(node):
        if node in dp_index_map:
            return
        mp_index_map[node] = -1
        dp_index_map[node] = -1
        if isinstance(node, DataloaderOp):
            return
        elif isinstance(node, OptimizerOp):
            nonlocal opt
            assert opt is None, 'Multiple optimizers are not supported.'
            opt = node
            for n in node.inputs:
                assign_ctx(n)
            grads = []
            original_params = node.optimizer.params
            for ind, param in enumerate(original_params):
                ori_grad = node.inputs[ind]
                if param in trainable_params:
                    new_grad = receive_model_parallel(ori_grad.inputs[0], param) if isinstance(
                        ori_grad, (DispatchOp, DispatchGradientOp)) else ori_grad
                    grads.append(new_grad)
                elif isinstance(ori_grad, (DispatchOp, DispatchGradientOp)):
                    real_input = ori_grad.inputs[0]
                    my_pos = dp_index_map[real_input]
                    if my_pos >= 0:
                        send_model_parallel(ori_grad.inputs[0], param)
            if trainable_params:
                # indices = [original_params.index(param) for param in trainable_params]
                node.optimizer.params = trainable_params
                # grads = [node.inputs[index] for index in indices]
                node.inputs = grads
                node.ctx = ctx
                my_eval_nodes.append(node)
        elif isinstance(node, DispatchOp):
            real_node = node.inputs[0]
            assign_ctx(real_node)
            node_tar_states_map[real_node] = node.parts
            node_tar_duplicate_map[real_node] = node.duplicate
        elif isinstance(node, DispatchGradientOp):
            real_node = node.inputs[0]
            assign_ctx(real_node)
            assign_ctx(node.inputs[1])
            node_tar_states_map[real_node] = node_cur_states_map.get(
                node.inputs[1], None)
            node_tar_duplicate_map[real_node] = node_cur_duplicate_map.get(
                node.inputs[1], 1)
        else:
            # currently we only support the SAME model-parallel layout across
            # data-parallel replicas, and each context may appear only once
            mp_index = -1
            dp_index = -1
            for i, c in enumerate(node.raw_ctx.workers):
                if isinstance(c, tuple) and ctx in c:
                    mp_index = c.index(ctx)
                    dp_index = i
                elif ctx == c:
                    dp_index = i
            mp_index_map[node] = mp_index
            dp_index_map[node] = dp_index
            need_states_deduction = False
            for i, n in enumerate(node.inputs):
                if isinstance(n, DataloaderOp):
                    if dp_index >= 0 and n in node_list and n not in my_eval_nodes:
                        my_eval_nodes.append(n)
                    continue
                assign_ctx(n)

                # we assume that in model parallel + data parallel mode the
                # number of devices in each stage is equal, and the devices
                # in corresponding positions communicate with each other
                # TODO: not supported yet: context(1,5) -> context(5,1); context(1,5) -> context(3,1)
                # solution: modify the following is_my_node logic to support it
                # TODO: not supported yet: processes with different numbers of
                # group initializations, since there is an AllGather in
                # mpi_nccl_comm's init
                # solution: modify the mpi_nccl_comm class so that the MPI part
                # runs only once while NCCL forms several groups
                assert node.raw_ctx.worker_num == n.raw_ctx.worker_num, \
                    'In pipeline + data parallel, the number of devices in each stage must be equal!'

                if isinstance(n, (DispatchOp, DispatchGradientOp)):
                    need_states_deduction = True
                    # we only allow pipeline + model parallel here, which
                    # means all the devices are different
                    # TODO: relax the constraint above
                    # within each context, every device may appear only once
                    # TODO: consider whether to relax the constraint above
                    # we only allow 1-to-N / N-to-1 / N-to-N per dimension; a
                    # dimension cannot change from x parts to y parts when
                    # x != 1, y != 1 and x != y
                    # TODO: consider whether to relax the constraint above;
                    # doing so is complex and not very realistic
                    real_input = n.inputs[0]
                    if dp_index >= 0 and dp_index_map[real_input] < 0:
                        node.inputs[i] = receive_model_parallel(
                            real_input, node)
                    elif dp_index < 0 and dp_index_map[real_input] >= 0:
                        send_model_parallel(real_input, node)
                else:
                    assert mp_index < 0 and mp_index_map[n] < 0
                    # handle receiving
                    if dp_index >= 0 and dp_index != dp_index_map[n]:
                        my_pos = dp_index
                        hostname = n.raw_ctx.workers[my_pos].hostname
                        target_id = n.raw_ctx.workers[my_pos].device_id
                        if n not in recv_src:
                            recv_src[n] = pipeline_receive_op(mpi_comm.getRankFromDevice(
                                hostname, target_id), mpi_comm, stream=p2p_stream, ctx=ctx)
                        node.inputs[i] = recv_src[n]
                    # handle sending
                    if dp_index_map[n] >= 0 and dp_index != dp_index_map[n]:
                        my_pos = dp_index_map[n]
                        hostname = node.raw_ctx.workers[my_pos].hostname
                        target_id = node.raw_ctx.workers[my_pos].device_id
                        # key by hostname as well, since device ids are only
                        # unique within a host
                        key = (n, hostname, target_id)
                        if key not in send_dst:
                            send_dst[key] = pipeline_send_op(n, mpi_comm.getRankFromDevice(
                                hostname, target_id), mpi_comm, stream=p2p_stream, ctx=ctx)
                            my_eval_nodes.append(send_dst[key])

            if dp_index >= 0:
                node.ctx = ctx
                if node in node_list:
                    my_eval_nodes.append(node)
                if isinstance(node, PlaceholderOp) and node.trainable:
                    trainable_params.append(node)
            if need_states_deduction:
                # gather the input states/duplicates (target values for
                # dispatch ops, current values otherwise) and let the node
                # deduce its own states
                input_states = []
                input_duplicates = []
                for n in node.inputs:
                    if isinstance(n, (DispatchOp, DispatchGradientOp)):
                        input_states.append(node_tar_states_map[n.inputs[0]])
                        input_duplicates.append(
                            node_tar_duplicate_map[n.inputs[0]])
                    else:
                        input_states.append(node_cur_states_map.get(n, None))
                        input_duplicates.append(
                            node_cur_duplicate_map.get(n, 1))
                node_cur_states_map[node], node_cur_duplicate_map[node] = node.deduce_states(
                    input_states, input_duplicates)

    opt = None
    trainable_params = []
    send_dst = {}
    recv_src = {}
    mp_index_map = {}  # model parallel index
    dp_index_map = {}  # data parallel index
    node_cur_duplicate_map = {}  # nodes' current duplicate information
    node_tar_duplicate_map = {}  # nodes' target duplicate information
    node_cur_states_map = {}  # nodes' current states
    node_tar_states_map = {}  # nodes' target states
    my_eval_nodes = []
    for node in node_list:
        assign_ctx(node)

    has_send_recv = bool(send_dst) or bool(recv_src)
    return my_eval_nodes, trainable_params, has_send_recv
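
# A minimal usage sketch (assuming `node_list` holds the graph roots and
# `mpi_comm` / `p2p_stream` come from the communicator setup elsewhere in the
# package):
#
#   eval_nodes, params, has_p2p = assign_context_by_traverse_nodes(
#       node_list, ctx, mpi_comm, p2p_stream)
#   # eval_nodes now also contains the pipeline send ops for this device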